From c15f1a2bd2276b2ed6e9b47fdb9b8f9b8da5b2dd Mon Sep 17 00:00:00 2001
From: Rob Suderman
Date: Wed, 6 Mar 2024 17:01:05 -0800
Subject: [PATCH] [onnx] Adding lowering for `onnx.Size` operation (#2985)

We can support `onnx.Size` by requesting the size of each dimension and
taking the product of the results, then packing it into a tensor.

---------

Co-authored-by: Scott Todd
---
 .../TorchOnnxToTorch/DefaultDomainQtoZ.cpp    | 44 +++++++++++++++++++
 .../TorchOnnxToTorch/simple_ops_q_to_z.mlir   | 21 +++++++++
 2 files changed, 65 insertions(+)

diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
index df3449939138..34282bfef531 100644
--- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
@@ -2032,6 +2032,50 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
             none, none, none);
         return success();
       });
+  patterns.onOp(
+      "Size", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+        Torch::ValueTensorType resultType;
+        Value operand;
+        if (binder.tensorOperand(operand) ||
+            binder.tensorResultType(resultType))
+          return failure();
+
+        auto loc = binder.getLoc();
+        auto &op = binder.op;
+        auto operandTy = cast<Torch::ValueTensorType>(operand.getType());
+
+        if (!operandTy.hasSizes())
+          return rewriter.notifyMatchFailure(op, "input rank unknown");
+
+        // Query each dimension's runtime extent; the element count is the
+        // product of these extents.
+        llvm::SmallVector<Value> dims;
+        int64_t rank = operandTy.getSizes().size();
+        for (int i = 0; i < rank; ++i) {
+          auto iv = rewriter.create<Torch::ConstantIntOp>(
+              loc, rewriter.getI64IntegerAttr(i));
+          Value dim = rewriter.create<Torch::AtenSizeIntOp>(
+              loc, rewriter.getType<Torch::IntType>(), operand, iv);
+          dims.push_back(dim);
+        }
+
+        Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);
+        Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
+
+        // Rank-0 input: the size of a scalar tensor is 1.
+        if (dims.empty()) {
+          Value one = rewriter.create<Torch::ConstantIntOp>(
+              loc, rewriter.getI64IntegerAttr(1));
+          rewriter.replaceOpWithNewOp<Torch::AtenTensorIntOp>(
+              op, resultType, one, none, none, cstFalse);
+          return success();
+        }
+
+        Value prod = dims[0];
+        for (int i = 1, s = dims.size(); i < s; ++i)
+          prod = rewriter.create<Torch::AtenMulIntOp>(loc, prod, dims[i]);
+
+        rewriter.replaceOpWithNewOp<Torch::AtenTensorIntOp>(
+            op, resultType, prod, none, none, cstFalse);
+        return success();
+      });
   patterns.onOp(
       "Tile", 6, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
         Torch::ValueTensorType resultType;
diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir
index 977c557739b5..bba74b6d9877 100644
--- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir
+++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir
@@ -1649,3 +1649,24 @@ func.func @test_sign(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,
   %0 = torch.operator "onnx.Sign"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32>
   return %0 : !torch.vtensor<[3,4,5],f32>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @test_size
+func.func @test_size(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[],si32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 9 : si64} {
+  // CHECK-DAG: %[[INT0:.+]] = torch.constant.int 0
+  // CHECK-DAG: %[[INT1:.+]] = torch.constant.int 1
+  // CHECK-DAG: %[[INT2:.+]] = torch.constant.int 2
+  // CHECK-DAG: %[[D0:.+]] = torch.aten.size.int %arg0, %[[INT0]]
+  // CHECK-DAG: %[[D1:.+]] = torch.aten.size.int %arg0, %[[INT1]]
+  // CHECK-DAG: %[[D2:.+]] = torch.aten.size.int %arg0, %[[INT2]]
+  // CHECK-DAG: %[[FALSE:.+]] = torch.constant.bool false
+  // CHECK-DAG: %[[NONE:.+]] = torch.constant.none
+  // CHECK-DAG: %[[MUL0:.+]] = torch.aten.mul.int %[[D0]], %[[D1]]
+  // CHECK-DAG: %[[MUL1:.+]] = torch.aten.mul.int %[[MUL0]], %[[D2]]
+  // CHECK-DAG: %[[TENSOR:.+]] = torch.aten.tensor.int %[[MUL1]], %[[NONE]], %[[NONE]], %[[FALSE]]
+  // CHECK: return %[[TENSOR]]
+  %0 = torch.operator "onnx.Size"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[],si32>
+  return %0 : !torch.vtensor<[],si32>
+}
+