9
9
10
10
#include " PassDetail.h"
11
11
12
+ #include " mlir/IR/Matchers.h"
12
13
#include " mlir/Transforms/GreedyPatternRewriteDriver.h"
13
14
#include " torch-mlir/Dialect/Torch/IR/TorchOps.h"
15
+ #include " torch-mlir/Dialect/Torch/IR/TorchTypes.h"
14
16
#include " torch-mlir/Dialect/Torch/Transforms/Passes.h"
17
+ #include < climits>
15
18
16
19
using namespace mlir ;
17
20
using namespace mlir ::torch;
@@ -33,7 +36,14 @@ class RecomposeSliceCopy_ : public OpRewritePattern<AtenCopy_Op> {
33
36
if (!matchPattern (sliceOp.getDim (), m_TorchConstantInt (&dim)))
34
37
return failure ();
35
38
int64_t end;
36
- if (!matchPattern (sliceOp.getEnd (), m_TorchConstantInt (&end)))
39
+ if (sliceOp.getEnd ().getType ().isa <Torch::NoneType>())
40
+ end = INT64_MAX;
41
+ else if (!matchPattern (sliceOp.getEnd (), m_TorchConstantInt (&end)))
42
+ return failure ();
43
+ int64_t start;
44
+ if (sliceOp.getStart ().getType ().isa <Torch::NoneType>())
45
+ start = INT64_MIN;
46
+ else if (!matchPattern (sliceOp.getStart (), m_TorchConstantInt (&start)))
37
47
return failure ();
38
48
39
49
Value newEnd = sliceOp.getEnd ();
@@ -42,6 +52,20 @@ class RecomposeSliceCopy_ : public OpRewritePattern<AtenCopy_Op> {
42
52
op.getLoc (), sliceOp.getSelf (), sliceOp.getDim ());
43
53
newEnd =
44
54
rewriter.create <AtenAddIntOp>(op.getLoc (), dimSize, sliceOp.getEnd ());
55
+ } else if (end == INT64_MAX) {
56
+ newEnd = rewriter.create <AtenSizeIntOp>(op.getLoc (), sliceOp.getSelf (),
57
+ sliceOp.getDim ());
58
+ }
59
+
60
+ Value newStart = sliceOp.getStart ();
61
+ if (start == INT64_MIN) {
62
+ newStart = rewriter.create <ConstantIntOp>(op.getLoc (),
63
+ rewriter.getI64IntegerAttr (0 ));
64
+ } else if (start < 0 ) {
65
+ Value dimSize = rewriter.create <AtenSizeIntOp>(
66
+ op.getLoc (), sliceOp.getSelf (), sliceOp.getDim ());
67
+ newStart =
68
+ rewriter.create <AtenAddIntOp>(op.getLoc (), dimSize, sliceOp.getEnd ());
45
69
}
46
70
47
71
Value noneVal = rewriter.create <ConstantNoneOp>(op.getLoc ());
@@ -50,7 +74,7 @@ class RecomposeSliceCopy_ : public OpRewritePattern<AtenCopy_Op> {
50
74
// Create IndexPut_Op
51
75
BaseTensorType tensorType = op->getResultTypes ()[0 ].cast <BaseTensorType>();
52
76
Value range = rewriter.create <AtenArangeStartStepOp>(
53
- op.getLoc (), tensorType, sliceOp. getStart () , newEnd, sliceOp.getStep (),
77
+ op.getLoc (), tensorType, newStart , newEnd, sliceOp.getStep (),
54
78
/* dtype=*/ noneVal, /* layout=*/ noneVal, /* device=*/ noneVal,
55
79
/* pin_memory=*/ noneVal);
56
80
0 commit comments