@@ -2196,45 +2196,6 @@ func.func @test_celu(%arg0: !torch.vtensor<[3,3,3,1],f32>) -> !torch.vtensor<[3,
2196
2196
2197
2197
// -----
2198
2198
2199
- // CHECK-LABEL: func.func @test_compress
2200
- func.func @test_compress (%arg0: !torch.vtensor <[2 ,3 ,4 ],f32 >, %arg1: !torch.vtensor <[3 ], si64 >) -> !torch.vtensor <[2 ,3 ,2 ], f32 > attributes {torch.onnx_meta.ir_version = 9 : si64 , torch.onnx_meta.opset_version = 9 : si64 , torch.onnx_meta.producer_name = " backend-test" , torch.onnx_meta.producer_version = " " } {
2201
- // CHECK: %[[INDEX:.*]] = torch.aten.nonzero %arg1 : !torch.vtensor<[3],si64> -> !torch.vtensor<[2],si64>
2202
- // CHECK: %[[DIM:.*]] = torch.constant.int 2
2203
- // CHECK: torch.aten.index_select %arg0, %[[DIM]], %[[INDEX]] : !torch.vtensor<[2,3,4],f32>, !torch.int, !torch.vtensor<[2],si64> -> !torch.vtensor<[2,3,2],f32>
2204
- %0 = torch.operator " onnx.Compress" (%arg0 , %arg1 ) {torch.onnx.axis = 2 : si64 } : (!torch.vtensor <[2 ,3 ,4 ],f32 >, !torch.vtensor <[3 ],si64 >) -> !torch.vtensor <[2 ,3 ,2 ],f32 >
2205
- return %0 : !torch.vtensor <[2 ,3 ,2 ],f32 >
2206
- }
2207
-
2208
- // -----
2209
-
2210
- // CHECK-LABEL: func.func @test_compress_default_axis
2211
- func.func @test_compress_default_axis (%arg0: !torch.vtensor <[2 ,3 ],f32 >) -> !torch.vtensor <[3 ], f32 > attributes {torch.onnx_meta.ir_version = 9 : si64 , torch.onnx_meta.opset_version = 9 : si64 , torch.onnx_meta.producer_name = " backend-test" , torch.onnx_meta.producer_version = " " } {
2212
- // CHECK: %[[CST:.*]] = torch.vtensor.literal(dense<[0, 1, 0, 1, 0, 1]> : tensor<6xsi64>) : !torch.vtensor<[6],si64>
2213
- // CHECK: %[[INDEX:.*]] = torch.aten.nonzero %[[CST]] : !torch.vtensor<[6],si64> -> !torch.vtensor<[3],si64>
2214
- // CHECK: %[[INT0:.*]] = torch.constant.int 0
2215
- // CHECK: %[[END_DIM:.*]] = torch.constant.int -1
2216
- // CHECK: %[[ATEN_FLATTEN:.*]] = torch.aten.flatten.using_ints %arg0, %[[INT0]], %[[END_DIM]] : !torch.vtensor<[2,3],f32>, !torch.int, !torch.int -> !torch.vtensor<[6],f32>
2217
- // CHECK: torch.aten.index_select %[[ATEN_FLATTEN]], %[[INT0]], %[[INDEX]] : !torch.vtensor<[6],f32>, !torch.int, !torch.vtensor<[3],si64> -> !torch.vtensor<[3],f32>
2218
- %cst = torch.vtensor.literal (dense <[0 ,1 ,0 ,1 ,0 ,1 ]> : tensor <6 xsi64 >) : !torch.vtensor <[6 ], si64 >
2219
- %0 = torch.operator " onnx.Compress" (%arg0 , %cst ) : (!torch.vtensor <[2 ,3 ],f32 >, !torch.vtensor <[6 ],si64 >) -> !torch.vtensor <[3 ],f32 >
2220
- return %0 : !torch.vtensor <[3 ],f32 >
2221
- }
2222
-
2223
- // -----
2224
-
2225
- // CHECK-LABEL: func.func @test_compress_neg_axis
2226
- func.func @test_compress_neg_axis (%arg0: !torch.vtensor <[2 ,3 ,4 ],f32 >) -> !torch.vtensor <[2 ,2 ,4 ], f32 > attributes {torch.onnx_meta.ir_version = 9 : si64 , torch.onnx_meta.opset_version = 9 : si64 , torch.onnx_meta.producer_name = " backend-test" , torch.onnx_meta.producer_version = " " } {
2227
- // CHECK: %[[CST:.*]] = torch.vtensor.literal(dense<[0, 1, 1]> : tensor<3xsi64>) : !torch.vtensor<[3],si64>
2228
- // CHECK: %[[INDEX:.*]] = torch.aten.nonzero %[[CST]] : !torch.vtensor<[3],si64> -> !torch.vtensor<[2],si64>
2229
- // CHECK: %[[DIM:.*]] = torch.constant.int 1
2230
- // CHECK: %[[ATEN_INDEX_SELECT:.*]] = torch.aten.index_select %arg0, %[[DIM]], %[[INDEX]] : !torch.vtensor<[2,3,4],f32>, !torch.int, !torch.vtensor<[2],si64> -> !torch.vtensor<[2,2,4],f32>
2231
- %cst = torch.vtensor.literal (dense <[0 ,1 ,1 ]> : tensor <3 xsi64 >) : !torch.vtensor <[3 ], si64 >
2232
- %0 = torch.operator " onnx.Compress" (%arg0 , %cst ) {torch.onnx.axis = -2 : si64 } : (!torch.vtensor <[2 ,3 ,4 ],f32 >, !torch.vtensor <[3 ],si64 >) -> !torch.vtensor <[2 ,2 ,4 ],f32 >
2233
- return %0 : !torch.vtensor <[2 ,2 ,4 ],f32 >
2234
- }
2235
-
2236
- // -----
2237
-
2238
2199
// CHECK-LABEL: func.func @test_einsum_batch_diagonal
2239
2200
func.func @test_einsum_batch_diagonal (%arg0: !torch.vtensor <[3 ,5 ,5 ],f64 >) -> !torch.vtensor <[3 ,5 ],f64 > attributes {torch.onnx_meta.ir_version = 7 : si64 , torch.onnx_meta.opset_version = 12 : si64 , torch.onnx_meta.producer_name = " backend-test" , torch.onnx_meta.producer_version = " " } {
2240
2201
// CHECK: %[[TENSORS:.*]] = torch.prim.ListConstruct %arg0 : (!torch.vtensor<[3,5,5],f64>) -> !torch.list<vtensor>
0 commit comments