@@ -1838,3 +1838,22 @@ func.func @torch.aten.slice.tensor$fold_full_domain_slice(%arg0: !torch.vtensor<
1838
1838
%0 = torch.aten.slice.Tensor %arg0 , %int0 , %int0 , %int -1 , %int1 : !torch.vtensor <[4 ], f32 >, !torch.int , !torch.int , !torch.int , !torch.int -> !torch.vtensor <[4 ], f32 >
1839
1839
return %0 : !torch.vtensor <[4 ],f32 >
1840
1840
}
1841
+
1842
+ // CHECK-LABEL: func.func @torch.aten.slice.tensor$slice_plus_copy
1843
+ // CHECK-SAME: %[[ARG0:.+]]: !torch.vtensor<[10,4,4],f32>
1844
+ // CHECK-SAME: %[[ARG1:.+]]: !torch.vtensor<[4,4,4],f32>
1845
+ // CHECK: %[[SLICE:.*]] = torch.aten.slice.Tensor %[[ARG0]], %[[INT0]], %[[INT2]], %[[INT6]], %[[INT1]] : !torch.vtensor<[10,4,4],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,4,4],f32>
1846
+ // CHECK: %[[ARANGE:.*]] = torch.aten.arange.start_step %[[INT2]], %[[INT6]], %[[INT1]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[4,4,4],f32>
1847
+ // CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[ARANGE]] : (!torch.vtensor<[4,4,4],f32>) -> !torch.list<optional<vtensor<[4,4,4],f32>>>
1848
+ // CHECK: %[[INDEXPUT:.*]] = torch.aten._index_put_impl_ %[[ARG0]], %[[LIST]], %[[ARG1]], %[[FALSE]], %[[FALSE]] : !torch.vtensor<[10,4,4],f32>, !torch.list<optional<vtensor<[4,4,4],f32>>>, !torch.vtensor<[4,4,4],f32>, !torch.bool, !torch.bool -> !torch.vtensor<[4,4,4],f32>
1849
+ // CHECK: return %[[ARG0]] : !torch.vtensor<[10,4,4],f32>
1850
+ func.func @torch.aten.slice.tensor$slice_plus_copy (%arg0: !torch.vtensor <[10 ,4 ,4 ],f32 >, %arg1: !torch.vtensor <[4 ,4 ,4 ],f32 >) -> !torch.vtensor <[10 ,4 ,4 ],f32 > {
1851
+ %false = torch.constant.bool false
1852
+ %int0 = torch.constant.int 0
1853
+ %int2 = torch.constant.int 2
1854
+ %int6 = torch.constant.int 6
1855
+ %int1 = torch.constant.int 1
1856
+ %1 = torch.aten.slice.Tensor %arg0 , %int0 , %int2 , %int6 , %int1 : !torch.vtensor <[10 ,4 ,4 ],f32 >, !torch.int , !torch.int , !torch.int , !torch.int -> !torch.vtensor <[4 ,4 ,4 ],f32 >
1857
+ %2 = torch.aten.copy_ %1 , %arg1 , %false : !torch.vtensor <[4 ,4 ,4 ],f32 >, !torch.vtensor <[4 ,4 ,4 ],f32 >, !torch.bool -> !torch.vtensor <[4 ,4 ,4 ],f32 >
1858
+ return %arg0 : !torch.vtensor <[10 ,4 ,4 ],f32 >
1859
+ }
0 commit comments