9
9
10
10
#include " PassDetail.h"
11
11
12
+ #include " mlir/IR/Matchers.h"
12
13
#include " mlir/Transforms/GreedyPatternRewriteDriver.h"
13
14
#include " torch-mlir/Dialect/Torch/IR/TorchOps.h"
15
+ #include " torch-mlir/Dialect/Torch/IR/TorchTypes.h"
14
16
#include " torch-mlir/Dialect/Torch/Transforms/Passes.h"
15
17
#include " torch-mlir/Dialect/Torch/Utils/Utils.h"
18
+ #include < climits>
16
19
17
20
using namespace mlir ;
18
21
using namespace mlir ::torch;
@@ -33,8 +36,30 @@ class RecomposeSliceCopy_ : public OpRewritePattern<AtenCopy_Op> {
33
36
int64_t dim;
34
37
if (!matchPattern (sliceOp.getDim (), m_TorchConstantInt (&dim)))
35
38
return failure ();
39
+
40
+ // TODO Comparing directly to INT64_MAX/INT64_MIN seems fragile.
41
+ // This is a potential general way of implementing the clamping in terms of other torch ops.
42
+ // https://github.com/llvm/torch-mlir/pull/2005
43
+ //
44
+ // def to_valid_dim(dim, max_dim):
45
+ // dim = torch.ops.prim.min(dim, max_dim)
46
+ // dim = torch.ops.prim.max(dim, -max_dim)
47
+ // is_neg = torch.ops.aten.lt(dim, 0)
48
+ // is_neg_int = torch.ops.aten.Int.bool(is_neg)
49
+ // return (dim + max_dim) * is_neg_int + dim * (1 - is_neg_int)
50
+ //
51
+ // dim_size = torch.ops.aten.size.int(slice, slice.get_dim())
52
+ // start = to_valid_dim(slice.get_start(), dim_size)
53
+ // end = to_valid_dim(slice.get_end(), dim_size)
36
54
int64_t end;
37
- if (!matchPattern (sliceOp.getEnd (), m_TorchConstantInt (&end)))
55
+ if (sliceOp.getEnd ().getType ().isa <Torch::NoneType>())
56
+ end = INT64_MAX;
57
+ else if (!matchPattern (sliceOp.getEnd (), m_TorchConstantInt (&end)))
58
+ return failure ();
59
+ int64_t start;
60
+ if (sliceOp.getStart ().getType ().isa <Torch::NoneType>())
61
+ start = INT64_MIN;
62
+ else if (!matchPattern (sliceOp.getStart (), m_TorchConstantInt (&start)))
38
63
return failure ();
39
64
40
65
Value newEnd = sliceOp.getEnd ();
@@ -43,6 +68,20 @@ class RecomposeSliceCopy_ : public OpRewritePattern<AtenCopy_Op> {
43
68
op.getLoc (), sliceOp.getSelf (), sliceOp.getDim ());
44
69
newEnd =
45
70
rewriter.create <AtenAddIntOp>(op.getLoc (), dimSize, sliceOp.getEnd ());
71
+ } else if (end == INT64_MAX) {
72
+ newEnd = rewriter.create <AtenSizeIntOp>(op.getLoc (), sliceOp.getSelf (),
73
+ sliceOp.getDim ());
74
+ }
75
+
76
+ Value newStart = sliceOp.getStart ();
77
+ if (start == INT64_MIN) {
78
+ newStart = rewriter.create <ConstantIntOp>(op.getLoc (),
79
+ rewriter.getI64IntegerAttr (0 ));
80
+ } else if (start < 0 ) {
81
+ Value dimSize = rewriter.create <AtenSizeIntOp>(
82
+ op.getLoc (), sliceOp.getSelf (), sliceOp.getDim ());
83
+ newStart =
84
+ rewriter.create <AtenAddIntOp>(op.getLoc (), dimSize, sliceOp.getStart ()); // negative start wraps: dimSize + start (was mistakenly adding getEnd ())
46
85
}
47
86
48
87
Value noneVal = rewriter.create <ConstantNoneOp>(op.getLoc ());
@@ -51,7 +90,7 @@ class RecomposeSliceCopy_ : public OpRewritePattern<AtenCopy_Op> {
51
90
// Create IndexPut_Op
52
91
BaseTensorType tensorType = op->getResultTypes ()[0 ].cast <BaseTensorType>();
53
92
Value range = rewriter.create <AtenArangeStartStepOp>(
54
- op.getLoc (), tensorType, sliceOp. getStart () , newEnd, sliceOp.getStep (),
93
+ op.getLoc (), tensorType, newStart , newEnd, sliceOp.getStep (),
55
94
/* dtype=*/ noneVal, /* layout=*/ noneVal, /* device=*/ noneVal,
56
95
/* pin_memory=*/ noneVal);
57
96
0 commit comments