-
Notifications
You must be signed in to change notification settings - Fork 611
Fold slice+copy_ into index_put_ #1901
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
//===----------------------------------------------------------------------===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// Also available under a BSD-style license. See LICENSE. | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#include "PassDetail.h" | ||
|
||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" | ||
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h" | ||
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h" | ||
|
||
using namespace mlir; | ||
using namespace mlir::torch; | ||
using namespace mlir::torch::Torch; | ||
|
||
namespace { | ||
class RecomposeSliceCopy_ : public OpRewritePattern<AtenCopy_Op> { | ||
public: | ||
using OpRewritePattern::OpRewritePattern; | ||
LogicalResult matchAndRewrite(AtenCopy_Op op, | ||
PatternRewriter &rewriter) const override { | ||
if (!op.getSelf().getDefiningOp() || | ||
!isa<AtenSliceTensorOp>(op.getSelf().getDefiningOp())) | ||
return failure(); | ||
auto sliceOp = cast<AtenSliceTensorOp>(op.getSelf().getDefiningOp()); | ||
|
||
// Get indices | ||
int64_t dim; | ||
if (!matchPattern(sliceOp.getDim(), m_TorchConstantInt(&dim))) | ||
return failure(); | ||
int64_t end; | ||
if (!matchPattern(sliceOp.getEnd(), m_TorchConstantInt(&end))) | ||
return failure(); | ||
|
||
Value newEnd = sliceOp.getEnd(); | ||
if (end < 0) { | ||
Value dimSize = rewriter.create<AtenSizeIntOp>( | ||
op.getLoc(), sliceOp.getSelf(), sliceOp.getDim()); | ||
newEnd = | ||
rewriter.create<AtenAddIntOp>(op.getLoc(), dimSize, sliceOp.getEnd()); | ||
} | ||
|
||
Value noneVal = rewriter.create<ConstantNoneOp>(op.getLoc()); | ||
Value falseVal = rewriter.create<ConstantBoolOp>(op.getLoc(), false); | ||
|
||
// Create IndexPut_Op | ||
BaseTensorType tensorType = op->getResultTypes()[0].cast<BaseTensorType>(); | ||
Value range = rewriter.create<AtenArangeStartStepOp>( | ||
op.getLoc(), tensorType, sliceOp.getStart(), newEnd, sliceOp.getStep(), | ||
/*dtype=*/noneVal, /*layout=*/noneVal, /*device=*/noneVal, | ||
/*pin_memory=*/noneVal); | ||
|
||
SmallVector<Value> indicesVector; | ||
for (auto i = 0; i < dim - 1; i++) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this should be There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think the correct indices should be There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. After convert the dim-1 to the dim, get the indices like this There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That's correct, right? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In [1]: import torch
In [2]: torch.ops.aten.index_put(torch.tensor([[0, 1, 2, 3]]),
...: (None, torch.tensor([1, 2, 3]),),
...: torch.tensor([[4, 5, 6]]))
Out[2]: tensor([[0, 4, 5, 6]]) That should work as expected. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, it correct. |
||
indicesVector.push_back(noneVal); | ||
indicesVector.push_back(range); | ||
Value indices = rewriter.create<PrimListConstructOp>( | ||
op.getLoc(), | ||
Torch::ListType::get(op->getContext(), | ||
Torch::OptionalType::get(tensorType)), | ||
indicesVector); | ||
|
||
rewriter.replaceOpWithNewOp<Aten_IndexPutImpl_Op>( | ||
op, op->getResultTypes(), sliceOp.getSelf(), indices, op.getSrc(), | ||
/*accumulate=*/falseVal, /*unsafe=*/falseVal); | ||
|
||
return success(); | ||
} | ||
}; | ||
} // namespace | ||
|
||
namespace { | ||
class RecomposeComplexOps | ||
: public DecomposeComplexOpsBase<RecomposeComplexOps> { | ||
public: | ||
RecomposeComplexOps() = default; | ||
void runOnOperation() override { | ||
MLIRContext *context = &getContext(); | ||
RewritePatternSet patterns(context); | ||
|
||
// pattern.add calls go here | ||
patterns.add<RecomposeSliceCopy_>(context); | ||
|
||
GreedyRewriteConfig config; | ||
config.useTopDownTraversal = true; | ||
config.maxIterations = GreedyRewriteConfig::kNoLimit; | ||
|
||
if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), | ||
config))) { | ||
return signalPassFailure(); | ||
} | ||
} | ||
}; | ||
} // namespace | ||
|
||
std::unique_ptr<OperationPass<func::FuncOp>> | ||
mlir::torch::Torch::createRecomposeComplexOps() { | ||
return std::make_unique<RecomposeComplexOps>(); | ||
} |
Uh oh!
There was an error while loading. Please reload this page.