Commit 1eb041e

Add a new RecomposeComplexOps pass, fold slice+copy_ into index_put_
1 parent 40c25ce commit 1eb041e
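
The fold rests on a simple equivalence: an in-place copy_ into a strided slice view writes the same elements as an index_put_ on the base tensor whose index along the sliced dimension is arange(start, end, step). A rough eager-PyTorch sketch of that equivalence for the dim-0 case (illustration only; the pass rewrites Torch-dialect IR, and none of the names below come from the commit):

import torch

x1 = torch.rand(10, 4, 4)
x2 = x1.clone()
y = torch.rand(4, 4, 4)

# Pattern the pass matches: copy_ into a slice view of x1.
torch.ops.aten.slice(x1, 0, 2, 6, 1).copy_(y)

# Recomposed form: index_put_ on the base tensor with an arange index
# (the pass emits the underlying aten::_index_put_impl_ variant).
x2.index_put_((torch.arange(2, 6, 1),), y, accumulate=False)

assert torch.equal(x1, x2)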

6 files changed (+153, -0 lines)

e2e_testing/xfail_sets.py

Lines changed: 2 additions & 0 deletions

@@ -848,6 +848,8 @@
     "DropoutTrainModule_basic",
     "StdCorrectionKeepDimModule_basic",
     "StdCorrectionNoneModule_basic",
+    "SliceCopy_Module_basic",
+    "SliceCopyNegative_Module_basic",
     "VarBiasedModule_basic",
     "VarCorrectionAllDimReduceModule_basic",
     "VarCorrectionEmptyDimModule_basic",

include/torch-mlir/Dialect/Torch/Transforms/Passes.h

Lines changed: 2 additions & 0 deletions

@@ -98,6 +98,8 @@ std::unique_ptr<OperationPass<ModuleOp>> createRefinePublicReturnPass();
 std::unique_ptr<OperationPass<func::FuncOp>>
 createDecomposeComplexOpsPass(ArrayRef<std::string> legalOps);
 
+std::unique_ptr<OperationPass<func::FuncOp>> createRecomposeComplexOps();
+
 std::unique_ptr<OperationPass<ModuleOp>> createPreprocessShapeLibraryPass();
 
 std::unique_ptr<OperationPass<ModuleOp>> createReifyShapeCalculationsPass();

lib/Dialect/Torch/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions

@@ -9,6 +9,7 @@ add_mlir_library(TorchMLIRTorchPasses
   LowerToBackendContract.cpp
   MaximizeValueSemantics.cpp
   PrepareForGlobalizeObjectGraph.cpp
+  RecomposeComplexOps.cpp
   ReduceOpVariants.cpp
   RefinePublicReturn.cpp
   RefineTypes.cpp

lib/Dialect/Torch/Transforms/Passes.cpp

Lines changed: 1 addition & 0 deletions

@@ -106,6 +106,7 @@ void mlir::torch::Torch::createTorchSimplificationPipeline(
   // Clean up again to avoid needing to go back around the fixed-point
   // iteration.
   pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
+  pm.addNestedPass<func::FuncOp>(createRecomposeComplexOps());
   // Reduce variants of ops to a smaller set of primitives.
   pm.addNestedPass<func::FuncOp>(createReduceOpVariantsPass());
   pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
lib/Dialect/Torch/Transforms/RecomposeComplexOps.cpp

Lines changed: 103 additions & 0 deletions

@@ -0,0 +1,103 @@
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+// Also available under a BSD-style license. See LICENSE.
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
+#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
+
+using namespace mlir;
+using namespace mlir::torch;
+using namespace mlir::torch::Torch;
+
+namespace {
+class RecomposeSliceCopy_ : public OpRewritePattern<AtenCopy_Op> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenCopy_Op op,
+                                PatternRewriter &rewriter) const override {
+    if (!op.getSelf().getDefiningOp() ||
+        !isa<AtenSliceTensorOp>(op.getSelf().getDefiningOp()))
+      return failure();
+    auto sliceOp = cast<AtenSliceTensorOp>(op.getSelf().getDefiningOp());
+
+    // Get indices
+    int64_t dim;
+    if (!matchPattern(sliceOp.getDim(), m_TorchConstantInt(&dim)))
+      return failure();
+    int64_t end;
+    if (!matchPattern(sliceOp.getEnd(), m_TorchConstantInt(&end)))
+      return failure();
+
+    Value newEnd = sliceOp.getEnd();
+    if (end < 0) {
+      Value dimSize = rewriter.create<AtenSizeIntOp>(
+          op.getLoc(), sliceOp.getSelf(), sliceOp.getDim());
+      newEnd =
+          rewriter.create<AtenAddIntOp>(op.getLoc(), dimSize, sliceOp.getEnd());
+    }
+
+    Value noneVal = rewriter.create<ConstantNoneOp>(op.getLoc());
+    Value falseVal = rewriter.create<ConstantBoolOp>(op.getLoc(), false);
+
+    // Create IndexPut_Op
+    BaseTensorType tensorType = op->getResultTypes()[0].cast<BaseTensorType>();
+    Value range = rewriter.create<AtenArangeStartStepOp>(
+        op.getLoc(), tensorType, sliceOp.getStart(), newEnd, sliceOp.getStep(),
+        /*dtype=*/noneVal, /*layout=*/noneVal, /*device=*/noneVal,
+        /*pin_memory=*/noneVal);
+
+    SmallVector<Value> indicesVector;
+    for (auto i = 0; i < dim - 1; i++)
+      indicesVector.push_back(noneVal);
+    indicesVector.push_back(range);
+    Value indices = rewriter.create<PrimListConstructOp>(
+        op.getLoc(),
+        Torch::ListType::get(op->getContext(),
+                             Torch::OptionalType::get(tensorType)),
+        indicesVector);
+
+    rewriter.replaceOpWithNewOp<Aten_IndexPutImpl_Op>(
+        op, op->getResultTypes(), sliceOp.getSelf(), indices, op.getSrc(),
+        /*accumulate=*/falseVal, /*unsafe=*/falseVal);
+
+    return success();
+  }
+};
+} // namespace
+
+namespace {
+class RecomposeComplexOps
+    : public DecomposeComplexOpsBase<RecomposeComplexOps> {
+public:
+  RecomposeComplexOps() = default;
+  void runOnOperation() override {
+    MLIRContext *context = &getContext();
+    RewritePatternSet patterns(context);
+
+    // pattern.add calls go here
+    patterns.add<RecomposeSliceCopy_>(context);
+
+    GreedyRewriteConfig config;
+    config.useTopDownTraversal = true;
+    config.maxIterations = GreedyRewriteConfig::kNoLimit;
+
+    if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
+                                            config))) {
+      return signalPassFailure();
+    }
+  }
+};
+} // namespace
+
+std::unique_ptr<OperationPass<func::FuncOp>>
+mlir::torch::Torch::createRecomposeComplexOps() {
+  return std::make_unique<RecomposeComplexOps>();
+}
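
Because aten.slice accepts a negative end, the pattern only fires when end is a constant and, for end < 0, rebuilds the effective bound as size(dim) + end (the AtenSizeIntOp + AtenAddIntOp pair above) before forming the arange index. A minimal eager-PyTorch sketch of that normalization for the dim-0 case, with made-up values rather than anything taken from this file:

import torch

x = torch.rand(10, 4, 4)
y = torch.rand(4, 4, 4)

dim, start, end, step = 0, 2, -4, 1

# Negative end: normalize against the size of the sliced dimension,
# mirroring the dimSize + end computation emitted by the pattern.
if end < 0:
    end = x.size(dim) + end  # 10 + (-4) == 6

index = torch.arange(start, end, step)  # tensor([2, 3, 4, 5])
x.index_put_((index,), y, accumulate=False)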

python/torch_mlir_e2e_test/test_suite/slice_like.py

Lines changed: 44 additions & 0 deletions

@@ -481,3 +481,47 @@ def forward(self, x):
 @register_test_case(module_factory=lambda: NarrowVerticalTest2())
 def NarrowVerticalTest2_basic(module, tu: TestUtils):
     module.forward(tu.rand(6,4))
+
+# ==============================================================================
+
+class SliceCopy_Module(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([10, 4, 4], torch.float32, True),
+        ([4, 4, 4], torch.float32, True),
+    ])
+    def forward(self, x, y):
+        xslice = torch.ops.aten.slice(x, 0, 2, 6, 1)
+        xslice.copy_(y)
+        return x
+
+
+@register_test_case(module_factory=lambda: SliceCopy_Module())
+def SliceCopy_Module_basic(module, tu: TestUtils):
+    module.forward(tu.rand(10, 4, 4), tu.rand(4, 4, 4))
+
+# ==============================================================================
+
+class SliceCopyNegative_Module(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.float32, True),
+        ([-1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, x, y):
+        xslice = torch.ops.aten.slice(x, 0, 2, -4, 1)
+        xslice.copy_(y)
+        return x
+
+
+@register_test_case(module_factory=lambda: SliceCopyNegative_Module())
+def SliceCopyNegative_Module_basic(module, tu: TestUtils):
+    module.forward(tu.rand(10, 4, 4), tu.rand(4, 4, 4))
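
Both new modules should leave x[2:6] equal to y, since end=-4 on a size-10 dimension names the same bound as end=6; the e2e harness compares the compiled result against eager PyTorch, which behaves like this short reference sketch (variable names here are illustrative, not part of the test suite):

import torch

x = torch.rand(10, 4, 4)
y = torch.rand(4, 4, 4)

# Same region as torch.ops.aten.slice(x, 0, 2, 6, 1).
torch.ops.aten.slice(x, 0, 2, -4, 1).copy_(y)
assert torch.equal(x[2:6], y)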
