
Commit 8ce63aa

AmosLewis authored and committed
[MLIR] Fix fold slice and copy int64_max bug
1 parent 6973abb commit 8ce63aa

File tree

1 file changed
lib/Dialect/Torch/Transforms/RecomposeComplexOps.cpp

Lines changed: 43 additions & 4 deletions
@@ -9,10 +9,13 @@
 
 #include "PassDetail.h"
 
+#include "mlir/IR/Matchers.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
+#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
 #include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
 #include "torch-mlir/Dialect/Torch/Utils/Utils.h"
+#include <climits>
 
 using namespace mlir;
 using namespace mlir::torch;
@@ -33,8 +36,30 @@ class RecomposeSliceCopy_ : public OpRewritePattern<AtenCopy_Op> {
     int64_t dim;
     if (!matchPattern(sliceOp.getDim(), m_TorchConstantInt(&dim)))
       return failure();
+
+    // TODO Comparing directly to INT64_MAX/INT64_MIN seems fragile.
+    // This is a potential general way of implementing the clamping in terms of other torch ops.
+    // https://github.com/llvm/torch-mlir/pull/2005
+    //
+    // def to_valid_dim(dim, max_dim):
+    //     dim = torch.ops.prim.min(dim, max_dim)
+    //     dim = torch.ops.prim.max(dim, -max_dim)
+    //     is_neg = torch.ops.aten.lt(dim, 0)
+    //     is_neg_int = torch.ops.aten.Int.bool(is_neg)
+    //     return (dim + max_dim) * is_neg_int + dim * (1 - is_neg_int)
+    //
+    // dim_size = torch.ops.aten.size.int(slice, slice.get_dim())
+    // start = to_valid_dim(slice.get_start(), dim_size)
+    // end = to_valid_dim(slice.get_end(), dim_size)
     int64_t end;
-    if (!matchPattern(sliceOp.getEnd(), m_TorchConstantInt(&end)))
+    if (sliceOp.getEnd().getType().isa<Torch::NoneType>())
+      end = INT64_MAX;
+    else if (!matchPattern(sliceOp.getEnd(), m_TorchConstantInt(&end)))
+      return failure();
+    int64_t start;
+    if (sliceOp.getStart().getType().isa<Torch::NoneType>())
+      start = INT64_MIN;
+    else if (!matchPattern(sliceOp.getStart(), m_TorchConstantInt(&start)))
       return failure();
 
     Value newEnd = sliceOp.getEnd();
@@ -43,6 +68,20 @@ class RecomposeSliceCopy_ : public OpRewritePattern<AtenCopy_Op> {
           op.getLoc(), sliceOp.getSelf(), sliceOp.getDim());
       newEnd =
          rewriter.create<AtenAddIntOp>(op.getLoc(), dimSize, sliceOp.getEnd());
+    } else if (end == INT64_MAX) {
+      newEnd = rewriter.create<AtenSizeIntOp>(op.getLoc(), sliceOp.getSelf(),
+                                              sliceOp.getDim());
+    }
+
+    Value newStart = sliceOp.getStart();
+    if (start == INT64_MIN) {
+      newStart = rewriter.create<ConstantIntOp>(op.getLoc(),
+                                                rewriter.getI64IntegerAttr(0));
+    } else if (start < 0) {
+      Value dimSize = rewriter.create<AtenSizeIntOp>(
+          op.getLoc(), sliceOp.getSelf(), sliceOp.getDim());
+      newStart =
+          rewriter.create<AtenAddIntOp>(op.getLoc(), dimSize, sliceOp.getEnd());
     }
 
     Value noneVal = rewriter.create<ConstantNoneOp>(op.getLoc());
@@ -51,12 +90,12 @@ class RecomposeSliceCopy_ : public OpRewritePattern<AtenCopy_Op> {
     // Create IndexPut_Op
     BaseTensorType tensorType = op->getResultTypes()[0].cast<BaseTensorType>();
     Value range = rewriter.create<AtenArangeStartStepOp>(
-        op.getLoc(), tensorType, sliceOp.getStart(), newEnd, sliceOp.getStep(),
+        op.getLoc(), tensorType, newStart, newEnd, sliceOp.getStep(),
         /*dtype=*/noneVal, /*layout=*/noneVal, /*device=*/noneVal,
         /*pin_memory=*/noneVal);
 
     SmallVector<Value> indicesVector;
-    for (auto i = 0; i < dim - 1; i++)
+    for (auto i = 0; i < dim; i++)
       indicesVector.push_back(noneVal);
     indicesVector.push_back(range);
     Value indices = rewriter.create<PrimListConstructOp>(
@@ -105,7 +144,7 @@ class RecomposeSelectFill_ : public OpRewritePattern<AtenFill_TensorOp> {
 
     // Create indicesVector for IndexPut_Op by TorchNone and indexTensor
     BaseTensorType tensorType = op->getResultTypes()[0].cast<BaseTensorType>();
-    SmallVector<Value> indicesVector(dim - 1, noneVal);
+    SmallVector<Value> indicesVector(dim, noneVal);
     indicesVector.push_back(indexTensor);
 
     Value indices = rewriter.create<PrimListConstructOp>(
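
The TODO comment added in RecomposeSliceCopy_ sketches the clamping in torch-op pseudocode. Below is a minimal runnable Python rendering of that sketch, with plain int arithmetic standing in for torch.ops.prim.min/max and aten.lt/aten.Int.bool; the name to_valid_dim comes from the comment, and the test values are illustrative.

def to_valid_dim(dim, max_dim):
    # Clamp into [-max_dim, max_dim] (prim.min / prim.max in the comment).
    dim = min(dim, max_dim)
    dim = max(dim, -max_dim)
    # Branch-free wrap-around for negative indices (aten.lt + aten.Int.bool).
    is_neg_int = 1 if dim < 0 else 0
    return (dim + max_dim) * is_neg_int + dim * (1 - is_neg_int)

assert to_valid_dim(-1, 5) == 4         # negative index wraps around
assert to_valid_dim(2**63 - 1, 5) == 5  # INT64_MAX clamps to the dim size
assert to_valid_dim(-(2**63), 5) == 0   # INT64_MIN clamps to 0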

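For context on the two fixes: PyTorch encodes an open-ended slice such as x[:, 1:] with end = INT64_MAX or None, so the recomposed index_put_ must clamp end to the dimension size; and the indices list needs dim leading Nones (not dim - 1) so the arange indexes the sliced dimension itself. A PyTorch-level sketch of the recomposed behavior under those assumptions (shapes and variable names are illustrative, not from the commit):

import torch

x = torch.zeros(3, 5)
src = torch.ones(3, 4)

# x[:, 1:].copy_(src) reaches RecomposeSliceCopy_ as a slice op with start=1
# and end=INT64_MAX (or None); the pattern rewrites it into an index_put_.
dim = 1
end = x.size(dim)                   # end == INT64_MAX clamps to the dim size
start = 1                           # start == None would become 0
indices = torch.arange(start, end)  # the AtenArangeStartStepOp in the diff

# `dim` leading Nones (the dim - 1 -> dim fix) so `indices` hits dimension 1.
x.index_put_([None] * dim + [indices], src)

assert torch.equal(x[:, 1:], src)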