
Commit b2ae471

Fold slice+copy_ into index_put_
1 parent 62250da commit b2ae471

6 files changed, +106 -1 lines changed

e2e_testing/xfail_sets.py

Lines changed: 1 addition & 0 deletions
@@ -837,6 +837,7 @@
     "DropoutTrainModule_basic",
     "StdCorrectionKeepDimModule_basic",
     "StdCorrectionNoneModule_basic",
+    "SliceCopy_Module_basic",
     "VarBiasedModule_basic",
     "VarCorrectionAllDimReduceModule_basic",
     "VarCorrectionEmptyDimModule_basic",

include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

Lines changed: 1 addition & 0 deletions
@@ -6380,6 +6380,7 @@ def Torch_AtenCopy_Op : Torch_Op<"aten.copy_", [
       printDefaultTorchOp(printer, *this, 3, 1);
     }
   }];
+  let hasCanonicalizer = 1;
 }
 
 def Torch_Aten_ToCopyOp : Torch_Op<"aten._to_copy", [

lib/Dialect/Torch/IR/TorchOps.cpp

Lines changed: 55 additions & 0 deletions
@@ -8,6 +8,8 @@
 //===----------------------------------------------------------------------===//
 
 #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
 #include "torch-mlir/Dialect/Torch/Utils/Utils.h"
 
 #include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -2134,6 +2136,59 @@ OpFoldResult AtenCatOp::fold(FoldAdaptor adaptor) {
   return list.getElements()[0];
 }
 
+//===----------------------------------------------------------------------===//
+// AtenCopy_Op
+//===----------------------------------------------------------------------===//
+
+void AtenCopy_Op::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                              MLIRContext *context) {
+  patterns.add(+[](AtenCopy_Op op, PatternRewriter &rewriter) {
+    if (!op.getSelf().getDefiningOp() ||
+        !isa<AtenSliceTensorOp>(op.getSelf().getDefiningOp()))
+      return failure();
+    auto sliceOp = cast<AtenSliceTensorOp>(op.getSelf().getDefiningOp());
+
+    // Get indices
+    int64_t dim;
+    if (!matchPattern(sliceOp.getDim(), m_TorchConstantInt(&dim)))
+      return failure();
+    int64_t end;
+    if (!matchPattern(sliceOp.getEnd(), m_TorchConstantInt(&end)) || end < 0)
+      return failure();
+    int64_t step;
+    if (!matchPattern(sliceOp.getStep(), m_TorchConstantInt(&step)) ||
+        step != 1)
+      return failure();
+
+    Value noneVal = rewriter.create<ConstantNoneOp>(op.getLoc());
+    Value falseVal = rewriter.create<ConstantBoolOp>(op.getLoc(), false);
+
+    // Create IndexPut_Op
+    BaseTensorType tensorType = op->getResultTypes()[0].cast<BaseTensorType>();
+    Value range = rewriter.create<AtenArangeStartStepOp>(
+        op.getLoc(), tensorType, sliceOp.getStart(), sliceOp.getEnd(),
+        sliceOp.getStep(),
+        /*dtype=*/noneVal, /*layout=*/noneVal, /*device=*/noneVal,
+        /*pin_memory=*/noneVal);
+
+    SmallVector<Value> indicesVector;
+    for (auto i = 0; i < dim - 1; i++)
+      indicesVector.push_back(noneVal);
+    indicesVector.push_back(range);
+    Value indices = rewriter.create<PrimListConstructOp>(
+        op.getLoc(),
+        Torch::ListType::get(op->getContext(),
+                             Torch::OptionalType::get(tensorType)),
+        indicesVector);
+
+    rewriter.replaceOpWithNewOp<Aten_IndexPutImpl_Op>(
+        op, op->getResultTypes(), sliceOp.getSelf(), indices, op.getSrc(),
+        /*accumulate=*/falseVal, /*unsafe=*/falseVal);
+
+    return success();
+  });
+}
+
 //===----------------------------------------------------------------------===//
 // AtenSliceTensorOp
 //===----------------------------------------------------------------------===//
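The fold relies on the fact that writing through a step-1 slice view is equivalent to an index put along the sliced dimension with arange(start, end, step) indices. A minimal eager-mode PyTorch sketch of that equivalence, for the dim-0 case the new tests exercise; the public index_put_ stands in here for the underlying aten::_index_put_impl_ that the pattern actually emits:

import torch

x = torch.rand(10, 4, 4)
y = torch.rand(4, 4, 4)

# What the matched IR does: write y into x through a slice view.
a = x.clone()
torch.ops.aten.slice(a, 0, 2, 6, 1).copy_(y)

# What the fold produces: an index put with arange indices along the
# sliced dimension, with accumulate=False.
b = x.clone()
b.index_put_((torch.arange(2, 6, 1),), y, accumulate=False)

assert torch.equal(a, b)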

python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py

Lines changed: 8 additions & 1 deletion
@@ -234,6 +234,12 @@ def emit_with_mutating_variants(key, **kwargs):
                 emitter_td,
                 traits=["IsTrailingUnderscoreInplaceVariant"] if not is_functional_op else [])
 
+    def emit_as_mutating_variant(key, **kwargs):
+        emit_op(registry[key],
+                emitter_td,
+                traits=["IsTrailingUnderscoreInplaceVariant"],
+                **kwargs)
+
     # ==========================================================================
     # `aten::` namespace.
     # ==========================================================================
@@ -461,7 +467,8 @@ def emit_with_mutating_variants(key, **kwargs):
     emit("aten::clone : (Tensor, int?) -> (Tensor)")
     emit("aten::lift_fresh_copy : (Tensor) -> (Tensor)")
     emit("aten::contiguous : (Tensor, int) -> (Tensor)")
-    emit_with_mutating_variants("aten::copy : (Tensor, Tensor, bool) -> (Tensor)")
+    emit("aten::copy : (Tensor, Tensor, bool) -> (Tensor)")
+    emit_as_mutating_variant("aten::copy_ : (Tensor, Tensor, bool) -> (Tensor)", has_canonicalizer=True)
     emit("aten::_to_copy : (Tensor, int?, int?, Device?, bool?, bool, int?) -> (Tensor)")
     emit("aten::detach : (Tensor) -> (Tensor)")
     emit("aten::embedding : (Tensor, Tensor, int, bool, bool) -> (Tensor)")

python/torch_mlir_e2e_test/test_suite/slice_like.py

Lines changed: 22 additions & 0 deletions
@@ -481,3 +481,25 @@ def forward(self, x):
 @register_test_case(module_factory=lambda: NarrowVerticalTest2())
 def NarrowVerticalTest2_basic(module, tu: TestUtils):
     module.forward(tu.rand(6,4))
+
+# ==============================================================================
+
+class SliceCopy_Module(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([10, 4, 4], torch.float32, True),
+        ([4, 4, 4], torch.float32, True),
+    ])
+    def forward(self, x, y):
+        xslice = torch.ops.aten.slice(x, 0, 2, 6, 1)
+        xslice.copy_(y)
+        return x
+
+
+@register_test_case(module_factory=lambda: SliceCopy_Module())
+def SliceCopy_Module_basic(module, tu: TestUtils):
+    module.forward(tu.rand(10, 4, 4), tu.rand(4, 4, 4))
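The test returns x rather than the slice on purpose: copy_ writes through the view, so the canonicalized index-put form must reproduce the mutation in the original tensor. A quick plain-PyTorch sanity check of that behavior, outside the e2e harness:

import torch

x = torch.rand(10, 4, 4)
y = torch.rand(4, 4, 4)
xslice = torch.ops.aten.slice(x, 0, 2, 6, 1)  # a view of rows 2..5 of x
xslice.copy_(y)                               # writes through the view
assert torch.equal(x[2:6], y)                 # mutation is visible in x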

test/Dialect/Torch/canonicalize.mlir

Lines changed: 19 additions & 0 deletions
@@ -1838,3 +1838,22 @@ func.func @torch.aten.slice.tensor$fold_full_domain_slice(%arg0: !torch.vtensor<
   %0 = torch.aten.slice.Tensor %arg0, %int0, %int0, %int-1, %int1 : !torch.vtensor<[4], f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4], f32>
   return %0 : !torch.vtensor<[4],f32>
 }
+
+// CHECK-LABEL: func.func @torch.aten.slice.tensor$slice_plus_copy
+// CHECK-SAME:    %[[ARG0:.+]]: !torch.vtensor<[10,4,4],f32>
+// CHECK-SAME:    %[[ARG1:.+]]: !torch.vtensor<[4,4,4],f32>
+// CHECK: %[[SLICE:.*]] = torch.aten.slice.Tensor %[[ARG0]], %[[INT0]], %[[INT2]], %[[INT6]], %[[INT1]] : !torch.vtensor<[10,4,4],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,4,4],f32>
+// CHECK: %[[ARANGE:.*]] = torch.aten.arange.start_step %[[INT2]], %[[INT6]], %[[INT1]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[4,4,4],f32>
+// CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[ARANGE]] : (!torch.vtensor<[4,4,4],f32>) -> !torch.list<optional<vtensor<[4,4,4],f32>>>
+// CHECK: %[[INDEXPUT:.*]] = torch.aten._index_put_impl_ %[[ARG0]], %[[LIST]], %[[ARG1]], %[[FALSE]], %[[FALSE]] : !torch.vtensor<[10,4,4],f32>, !torch.list<optional<vtensor<[4,4,4],f32>>>, !torch.vtensor<[4,4,4],f32>, !torch.bool, !torch.bool -> !torch.vtensor<[4,4,4],f32>
+// CHECK: return %[[ARG0]] : !torch.vtensor<[10,4,4],f32>
+func.func @torch.aten.slice.tensor$slice_plus_copy(%arg0: !torch.vtensor<[10,4,4],f32>, %arg1: !torch.vtensor<[4,4,4],f32>) -> !torch.vtensor<[10,4,4],f32> {
+  %false = torch.constant.bool false
+  %int0 = torch.constant.int 0
+  %int2 = torch.constant.int 2
+  %int6 = torch.constant.int 6
+  %int1 = torch.constant.int 1
+  %1 = torch.aten.slice.Tensor %arg0, %int0, %int2, %int6, %int1 : !torch.vtensor<[10,4,4],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,4,4],f32>
+  %2 = torch.aten.copy_ %1, %arg1, %false : !torch.vtensor<[4,4,4],f32>, !torch.vtensor<[4,4,4],f32>, !torch.bool -> !torch.vtensor<[4,4,4],f32>
+  return %arg0 : !torch.vtensor<[10,4,4],f32>
+}
