Skip to content

Commit 160cdeb

Browse files
AmosLewisAmosLewis
authored andcommitted
[MLIR] Fix fold slice and copy int64_max bug
1 parent a744978 commit 160cdeb

File tree

1 file changed

+26
-2
lines changed

1 file changed

+26
-2
lines changed

lib/Dialect/Torch/Transforms/RecomposeComplexOps.cpp

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,12 @@
99

1010
#include "PassDetail.h"
1111

12+
#include "mlir/IR/Matchers.h"
1213
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
1314
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
15+
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
1416
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
17+
#include <climits>
1518

1619
using namespace mlir;
1720
using namespace mlir::torch;
@@ -33,7 +36,14 @@ class RecomposeSliceCopy_ : public OpRewritePattern<AtenCopy_Op> {
3336
if (!matchPattern(sliceOp.getDim(), m_TorchConstantInt(&dim)))
3437
return failure();
3538
int64_t end;
36-
if (!matchPattern(sliceOp.getEnd(), m_TorchConstantInt(&end)))
39+
if (sliceOp.getEnd().getType().isa<Torch::NoneType>())
40+
end = INT64_MAX;
41+
else if (!matchPattern(sliceOp.getEnd(), m_TorchConstantInt(&end)))
42+
return failure();
43+
int64_t start;
44+
if (sliceOp.getStart().getType().isa<Torch::NoneType>())
45+
start = INT64_MIN;
46+
else if (!matchPattern(sliceOp.getStart(), m_TorchConstantInt(&start)))
3747
return failure();
3848

3949
Value newEnd = sliceOp.getEnd();
@@ -42,6 +52,20 @@ class RecomposeSliceCopy_ : public OpRewritePattern<AtenCopy_Op> {
4252
op.getLoc(), sliceOp.getSelf(), sliceOp.getDim());
4353
newEnd =
4454
rewriter.create<AtenAddIntOp>(op.getLoc(), dimSize, sliceOp.getEnd());
55+
} else if (end == INT64_MAX) {
56+
newEnd = rewriter.create<AtenSizeIntOp>(op.getLoc(), sliceOp.getSelf(),
57+
sliceOp.getDim());
58+
}
59+
60+
Value newStart = sliceOp.getStart();
61+
if (start == INT64_MIN) {
62+
newStart = rewriter.create<ConstantIntOp>(op.getLoc(),
63+
rewriter.getI64IntegerAttr(0));
64+
} else if (start < 0) {
65+
Value dimSize = rewriter.create<AtenSizeIntOp>(
66+
op.getLoc(), sliceOp.getSelf(), sliceOp.getDim());
67+
newStart =
68+
rewriter.create<AtenAddIntOp>(op.getLoc(), dimSize, sliceOp.getEnd());
4569
}
4670

4771
Value noneVal = rewriter.create<ConstantNoneOp>(op.getLoc());
@@ -50,7 +74,7 @@ class RecomposeSliceCopy_ : public OpRewritePattern<AtenCopy_Op> {
5074
// Create IndexPut_Op
5175
BaseTensorType tensorType = op->getResultTypes()[0].cast<BaseTensorType>();
5276
Value range = rewriter.create<AtenArangeStartStepOp>(
53-
op.getLoc(), tensorType, sliceOp.getStart(), newEnd, sliceOp.getStep(),
77+
op.getLoc(), tensorType, newStart, newEnd, sliceOp.getStep(),
5478
/*dtype=*/noneVal, /*layout=*/noneVal, /*device=*/noneVal,
5579
/*pin_memory=*/noneVal);
5680

0 commit comments

Comments
 (0)