Skip to content

Commit 2399017

Browse files
committed
Disable onnx.Compress op that use torch.nonzero op
1 parent 9cd4b24 commit 2399017

File tree

2 files changed

+1
-40
lines changed

2 files changed

+1
-40
lines changed

lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1051,7 +1051,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
10511051
binder.getLoc(), rewriter.getI64IntegerAttr(axis));
10521052
rewriter.replaceOpWithNewOp<Torch::AtenIndexSelectOp>(
10531053
binder.op, resultType, operand, dimVal, indexVal);
1054-
return success();
1054+
return failure();
10551055
});
10561056
patterns.onOp(
10571057
"Concat", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) {

test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir

Lines changed: 0 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -2196,45 +2196,6 @@ func.func @test_celu(%arg0: !torch.vtensor<[3,3,3,1],f32>) -> !torch.vtensor<[3,
21962196

21972197
// -----
21982198

2199-
// CHECK-LABEL: func.func @test_compress
2200-
func.func @test_compress(%arg0: !torch.vtensor<[2,3,4],f32>, %arg1: !torch.vtensor<[3], si64>) -> !torch.vtensor<[2,3,2], f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 9 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
2201-
// CHECK: %[[INDEX:.*]] = torch.aten.nonzero %arg1 : !torch.vtensor<[3],si64> -> !torch.vtensor<[2],si64>
2202-
// CHECK: %[[DIM:.*]] = torch.constant.int 2
2203-
// CHECK: torch.aten.index_select %arg0, %[[DIM]], %[[INDEX]] : !torch.vtensor<[2,3,4],f32>, !torch.int, !torch.vtensor<[2],si64> -> !torch.vtensor<[2,3,2],f32>
2204-
%0 = torch.operator "onnx.Compress"(%arg0, %arg1) {torch.onnx.axis = 2 : si64} : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[3],si64>) -> !torch.vtensor<[2,3,2],f32>
2205-
return %0 : !torch.vtensor<[2,3,2],f32>
2206-
}
2207-
2208-
// -----
2209-
2210-
// CHECK-LABEL: func.func @test_compress_default_axis
2211-
func.func @test_compress_default_axis(%arg0: !torch.vtensor<[2,3],f32>) -> !torch.vtensor<[3], f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 9 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
2212-
// CHECK: %[[CST:.*]] = torch.vtensor.literal(dense<[0, 1, 0, 1, 0, 1]> : tensor<6xsi64>) : !torch.vtensor<[6],si64>
2213-
// CHECK: %[[INDEX:.*]] = torch.aten.nonzero %[[CST]] : !torch.vtensor<[6],si64> -> !torch.vtensor<[3],si64>
2214-
// CHECK: %[[INT0:.*]] = torch.constant.int 0
2215-
// CHECK: %[[END_DIM:.*]] = torch.constant.int -1
2216-
// CHECK: %[[ATEN_FLATTEN:.*]] = torch.aten.flatten.using_ints %arg0, %[[INT0]], %[[END_DIM]] : !torch.vtensor<[2,3],f32>, !torch.int, !torch.int -> !torch.vtensor<[6],f32>
2217-
// CHECK: torch.aten.index_select %[[ATEN_FLATTEN]], %[[INT0]], %[[INDEX]] : !torch.vtensor<[6],f32>, !torch.int, !torch.vtensor<[3],si64> -> !torch.vtensor<[3],f32>
2218-
%cst = torch.vtensor.literal(dense<[0,1,0,1,0,1]> : tensor<6xsi64>) : !torch.vtensor<[6], si64>
2219-
%0 = torch.operator "onnx.Compress"(%arg0, %cst) : (!torch.vtensor<[2,3],f32>, !torch.vtensor<[6],si64>) -> !torch.vtensor<[3],f32>
2220-
return %0 : !torch.vtensor<[3],f32>
2221-
}
2222-
2223-
// -----
2224-
2225-
// CHECK-LABEL: func.func @test_compress_neg_axis
2226-
func.func @test_compress_neg_axis(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,2,4], f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 9 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
2227-
// CHECK: %[[CST:.*]] = torch.vtensor.literal(dense<[0, 1, 1]> : tensor<3xsi64>) : !torch.vtensor<[3],si64>
2228-
// CHECK: %[[INDEX:.*]] = torch.aten.nonzero %[[CST]] : !torch.vtensor<[3],si64> -> !torch.vtensor<[2],si64>
2229-
// CHECK: %[[DIM:.*]] = torch.constant.int 1
2230-
// CHECK: %[[ATEN_INDEX_SELECT:.*]] = torch.aten.index_select %arg0, %[[DIM]], %[[INDEX]] : !torch.vtensor<[2,3,4],f32>, !torch.int, !torch.vtensor<[2],si64> -> !torch.vtensor<[2,2,4],f32>
2231-
%cst = torch.vtensor.literal(dense<[0,1,1]> : tensor<3xsi64>) : !torch.vtensor<[3], si64>
2232-
%0 = torch.operator "onnx.Compress"(%arg0, %cst) {torch.onnx.axis = -2 : si64} : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[3],si64>) -> !torch.vtensor<[2,2,4],f32>
2233-
return %0 : !torch.vtensor<[2,2,4],f32>
2234-
}
2235-
2236-
// -----
2237-
22382199
// CHECK-LABEL: func.func @test_einsum_batch_diagonal
22392200
func.func @test_einsum_batch_diagonal(%arg0: !torch.vtensor<[3,5,5],f64>) -> !torch.vtensor<[3,5],f64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
22402201
// CHECK: %[[TENSORS:.*]] = torch.prim.ListConstruct %arg0 : (!torch.vtensor<[3,5,5],f64>) -> !torch.list<vtensor>

0 commit comments

Comments
 (0)