diff --git a/include/xten/Dialect/XTenNN/IR/XTenNNOps.td b/include/xten/Dialect/XTenNN/IR/XTenNNOps.td
index 3a43d375..8c61b8be 100644
--- a/include/xten/Dialect/XTenNN/IR/XTenNNOps.td
+++ b/include/xten/Dialect/XTenNN/IR/XTenNNOps.td
@@ -562,20 +562,25 @@ def XtenNN_ConvTransposeOp: XTenNN_Op<"ConvTranspose",[Pure, TosaExtension]> {
   let assemblyFormat = [{ operands attr-dict `:` functional-type(operands, results) }];
 }
 
-def XtenNN_ReduceMeanOp: XTenNN_Op<"reduce_mean", [Pure, TosaExtension]> {
+def XtenNN_ReduceMeanOp: XTenNN_Op<"reduce_mean", [
+    Pure, TosaExtension,
+    InferTensorTypeAdaptor]> {
   let summary = "Reduce Mean operation";
   let description = [{
     This operation is equivalent to `onnx.ReduceMean` and computes the mean of
     the input tensor's elements along the provided axes.
   }];
+
   let arguments = (ins
     AnyRankedTensor:$input,
     DenseI64ArrayAttr:$axes,
     I64Attr:$keepdims
   );
+
   let results = (outs
     AnyRankedTensor:$output
   );
+
   let assemblyFormat = [{ operands attr-dict `:` functional-type(operands, results) }];
 }
diff --git a/lib/Conversion/XTenNNToTorch.cpp b/lib/Conversion/XTenNNToTorch.cpp
index ed96f980..fa6b79c6 100644
--- a/lib/Conversion/XTenNNToTorch.cpp
+++ b/lib/Conversion/XTenNNToTorch.cpp
@@ -220,6 +220,21 @@ convTranspose2dToTorch(ConvTransposeOp op, ConvTransposeOp::Adaptor adaptor,
       ->getResults();
 }
 
+std::optional<ValueRange>
+reduceMeanToTorch(ReduceMeanOp op, ReduceMeanOp::Adaptor adaptor,
+                  ArrayRef<Type> types, ValueRange values,
+                  ConversionPatternRewriter &rewriter) {
+  auto loc = op->getLoc();
+  auto noneConst = rewriter.create<Torch::ConstantNoneOp>(loc);
+  auto keepdims =
+      rewriter.create<Torch::ConstantBoolOp>(loc, adaptor.getKeepdims());
+  auto axes = Torch::toTorchList(loc, rewriter, adaptor.getAxes().vec());
+  return rewriter
+      .create<Torch::AtenMeanDimOp>(loc, types[0], values[0], axes, keepdims,
+                                    noneConst)
+      ->getResults();
+}
+
 std::optional<ValueRange>
 resizeToTorch(ResizeOp op, ResizeOp::Adaptor adaptor, ArrayRef<Type> types,
               ValueRange values, ConversionPatternRewriter &rewriter) {
@@ -439,6 +454,7 @@ struct ConvertXTenNNToTorch
     patterns.add>(context);
     patterns.add>(
         context);
+    patterns.add>(context);
     if (failed(applyPartialConversion(funcOp, target, std::move(patterns))))
       signalPassFailure();
   }
diff --git a/lib/Dialect/XTenNN/IR/XTenNNOps.cpp b/lib/Dialect/XTenNN/IR/XTenNNOps.cpp
index 381a8342..cc4c67af 100644
--- a/lib/Dialect/XTenNN/IR/XTenNNOps.cpp
+++ b/lib/Dialect/XTenNN/IR/XTenNNOps.cpp
@@ -10,6 +10,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "llvm/ADT/SmallVector.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinOps.h"
@@ -26,6 +27,7 @@
 #include "xten/Dialect/XTenNN/IR/XTenNNBase.h"
 #include "xten/Dialect/XTenNN/IR/XTenNNOps.h"
 #include "xten/Dialect/XTenNN/Interfaces/EnclaveOpInterfaces.h"
+#include <cstdint>
 
 using namespace mlir;
 using namespace amd::xten_nn;
@@ -264,7 +266,9 @@ ParseResult SubgraphOp::parse(OpAsmParser &p, OperationState &result) {
   return parseEnclaveOp(p, result);
 }
 
-void SubgraphOp::print(OpAsmPrinter &p) { printEnclaveOp(p, *this); }
+void SubgraphOp::print(OpAsmPrinter &p) {
+  printEnclaveOp(p, *this);
+}
 
 LogicalResult SubgraphOp::verify() {
   Block *optBody = this->getOptionalEnclaveBody();
@@ -593,3 +597,48 @@ bool TopK::isCompatibleReturnTypes(mlir::TypeRange l, mlir::TypeRange r) {
       getElementTypeOrSelf(l[1]) == getElementTypeOrSelf(r[1]);
   return sameElementType && succeeded(verifyCompatibleShapes(l, r));
 }
+
+LogicalResult ReduceMeanOp::inferReturnTypeComponents(
+    MLIRContext * /*context*/, std::optional<Location> location,
+    ReduceMeanOp::Adaptor adaptor,
+    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+
+  auto inTy = cast<RankedTensorType>(adaptor.getInput().getType());
+  auto keepDims = adaptor.getKeepdims();
+  auto axes = adaptor.getAxes();
+
+  // Sanitize axes
+  llvm::SmallVector<int64_t> newAxes;
+  for (auto axis : axes) {
+    // onnx spec: axis: [-r, r-1]
+    if (axis < -inTy.getRank() || axis >= inTy.getRank()) {
+      return emitOptionalError(location,
+                               "expected axis to be within [-rank,rank) (where "
+                               "rank is the rank of the input)");
+    }
+
+    // normalize axis: [0, r)
+    if (axis < 0) {
+      axis += inTy.getRank();
+    }
+
+    assert((axis >= 0 && axis < inTy.getRank()) && "axis has invalid value");
+    newAxes.push_back(axis);
+  }
+
+  SmallVector<int64_t> outputShape;
+  auto inputShape = inTy.getShape();
+  for (auto [idx, dim] : llvm::enumerate(inputShape)) {
+    // Compare against the normalized axes so negative axes are honored.
+    if (llvm::is_contained(newAxes, static_cast<int64_t>(idx))) {
+      if (keepDims) {
+        outputShape.push_back(1);
+      }
+    } else {
+      outputShape.push_back(dim);
+    }
+  }
+
+  inferredReturnShapes.push_back(
+      ShapedTypeComponents(outputShape, inTy.getElementType()));
+  return success();
+}
\ No newline at end of file
diff --git a/test/Conversion/XTenNNToTorch/reduce_mean.mlir b/test/Conversion/XTenNNToTorch/reduce_mean.mlir
new file mode 100644
index 00000000..45d4a1e7
--- /dev/null
+++ b/test/Conversion/XTenNNToTorch/reduce_mean.mlir
@@ -0,0 +1,288 @@
+// RUN: aten-opt --convert-xtennn-to-torch -split-input-file %s | FileCheck %s
+// REQUIRES: torch
+
+func.func @reduce_mean_one_axis_keep_dims(%arg0: tensor<4x512x256x8xf32>) -> tensor<4x512x1x8xf32> {
+  %0 = xten_nn.reduce_mean %arg0 {axes = array<i64: 2>, keepdims = 1 : i64} : (tensor<4x512x256x8xf32>) -> tensor<4x512x1x8xf32>
+  return %0 : tensor<4x512x1x8xf32>
+}
+
+// CHECK-LABEL: func.func @reduce_mean_one_axis_keep_dims(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<4x512x256x8xf32>) -> tensor<4x512x1x8xf32> attributes {torch.onnx_meta.opset_version = 19 : si64} {
+// CHECK: %[[VAL_1:.*]] = torch_c.from_builtin_tensor %[[VAL_0]] : tensor<4x512x256x8xf32> -> !torch.vtensor<[4,512,256,8],f32>
+// CHECK: %[[VAL_2:.*]] = torch.constant.none
+// CHECK: %[[VAL_3:.*]] = torch.constant.bool true
+// CHECK: %[[VAL_4:.*]] = torch.constant.int 2
+// CHECK: %[[VAL_5:.*]] = torch.prim.ListConstruct %[[VAL_4]] : (!torch.int) -> !torch.list<int>
+// CHECK: %[[VAL_6:.*]] = torch.aten.mean.dim %[[VAL_1]], %[[VAL_5]], %[[VAL_3]], %[[VAL_2]] : !torch.vtensor<[4,512,256,8],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[4,512,1,8],f32>
+// CHECK: %[[VAL_7:.*]] = torch_c.to_builtin_tensor %[[VAL_6]] : !torch.vtensor<[4,512,1,8],f32> -> tensor<4x512x1x8xf32>
+// CHECK: return %[[VAL_7]] : tensor<4x512x1x8xf32>
+// CHECK: }
+
+// -----
+
+func.func @reduce_mean_three_axes_keep_dims(%arg0: tensor<4x512x256x8xf32>) -> tensor<4x1x1x1xf32> {
+  %0 = xten_nn.reduce_mean %arg0 {axes = array<i64: 1, 2, 3>, keepdims = 1 : i64} : (tensor<4x512x256x8xf32>) -> tensor<4x1x1x1xf32>
+  return %0 : tensor<4x1x1x1xf32>
+}
+
+// CHECK-LABEL: func.func @reduce_mean_three_axes_keep_dims(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<4x512x256x8xf32>) -> tensor<4x1x1x1xf32> attributes {torch.onnx_meta.opset_version = 19 : si64} {
+// CHECK: %[[VAL_1:.*]] = torch_c.from_builtin_tensor %[[VAL_0]] : tensor<4x512x256x8xf32> -> !torch.vtensor<[4,512,256,8],f32>
+// CHECK: %[[VAL_2:.*]] = torch.constant.none
+// CHECK: %[[VAL_3:.*]] = torch.constant.bool true
+// CHECK: %[[VAL_4:.*]] = torch.constant.int 1
+// CHECK: %[[VAL_5:.*]] = torch.constant.int 2
+// CHECK: %[[VAL_6:.*]] = torch.constant.int 3
+// CHECK: %[[VAL_7:.*]] = torch.prim.ListConstruct %[[VAL_4]], %[[VAL_5]], %[[VAL_6]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[VAL_8:.*]] = torch.aten.mean.dim %[[VAL_1]], %[[VAL_7]], %[[VAL_3]], %[[VAL_2]] : !torch.vtensor<[4,512,256,8],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1,1],f32>
+// CHECK: %[[VAL_9:.*]] = torch_c.to_builtin_tensor %[[VAL_8]] : !torch.vtensor<[4,1,1,1],f32> -> tensor<4x1x1x1xf32>
+// CHECK: return %[[VAL_9]] : tensor<4x1x1x1xf32>
+// CHECK: }
+
+// -----
+
+func.func @reduce_mean_all_axes_keep_dims(%arg0: tensor<4x512x256x8xf32>) -> tensor<1x1x1x1xf32> {
+  %0 = xten_nn.reduce_mean %arg0 {axes = array<i64: 0, 1, 2, 3>, keepdims = 1 : i64} : (tensor<4x512x256x8xf32>) -> tensor<1x1x1x1xf32>
+  return %0 : tensor<1x1x1x1xf32>
+}
+
+// CHECK-LABEL: func.func @reduce_mean_all_axes_keep_dims(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<4x512x256x8xf32>) -> tensor<1x1x1x1xf32> attributes {torch.onnx_meta.opset_version = 19 : si64} {
+// CHECK: %[[VAL_1:.*]] = torch_c.from_builtin_tensor %[[VAL_0]] : tensor<4x512x256x8xf32> -> !torch.vtensor<[4,512,256,8],f32>
+// CHECK: %[[VAL_2:.*]] = torch.constant.none
+// CHECK: %[[VAL_3:.*]] = torch.constant.bool true
+// CHECK: %[[VAL_4:.*]] = torch.constant.int 0
+// CHECK: %[[VAL_5:.*]] = torch.constant.int 1
+// CHECK: %[[VAL_6:.*]] = torch.constant.int 2
+// CHECK: %[[VAL_7:.*]] = torch.constant.int 3
+// CHECK: %[[VAL_8:.*]] = torch.prim.ListConstruct %[[VAL_4]], %[[VAL_5]], %[[VAL_6]], %[[VAL_7]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[VAL_9:.*]] = torch.aten.mean.dim %[[VAL_1]], %[[VAL_8]], %[[VAL_3]], %[[VAL_2]] : !torch.vtensor<[4,512,256,8],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,1,1,1],f32>
+// CHECK: %[[VAL_10:.*]] = torch_c.to_builtin_tensor %[[VAL_9]] : !torch.vtensor<[1,1,1,1],f32> -> tensor<1x1x1x1xf32>
+// CHECK: return %[[VAL_10]] : tensor<1x1x1x1xf32>
+// CHECK: }
+
+// -----
+
+func.func @reduce_mean_one_axis(%arg0: tensor<4x512x256x8xf32>) -> tensor<4x512x8xf32> {
+  %0 = xten_nn.reduce_mean %arg0 {axes = array<i64: 2>, keepdims = 0 : i64} : (tensor<4x512x256x8xf32>) -> tensor<4x512x8xf32>
+  return %0 : tensor<4x512x8xf32>
+}
+
+// CHECK-LABEL: func.func @reduce_mean_one_axis(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<4x512x256x8xf32>) -> tensor<4x512x8xf32> attributes {torch.onnx_meta.opset_version = 19 : si64} {
+// CHECK: %[[VAL_1:.*]] = torch_c.from_builtin_tensor %[[VAL_0]] : tensor<4x512x256x8xf32> -> !torch.vtensor<[4,512,256,8],f32>
+// CHECK: %[[VAL_2:.*]] = torch.constant.none
+// CHECK: %[[VAL_3:.*]] = torch.constant.bool false
+// CHECK: %[[VAL_4:.*]] = torch.constant.int 2
+// CHECK: %[[VAL_5:.*]] = torch.prim.ListConstruct %[[VAL_4]] : (!torch.int) -> !torch.list<int>
+// CHECK: %[[VAL_6:.*]] = torch.aten.mean.dim %[[VAL_1]], %[[VAL_5]], %[[VAL_3]], %[[VAL_2]] : !torch.vtensor<[4,512,256,8],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[4,512,8],f32>
+// CHECK: %[[VAL_7:.*]] = torch_c.to_builtin_tensor %[[VAL_6]] : !torch.vtensor<[4,512,8],f32> -> tensor<4x512x8xf32>
+// CHECK: return %[[VAL_7]] : tensor<4x512x8xf32>
+// CHECK: }
+
+// -----
+
+func.func @reduce_mean_three_axes(%arg0: tensor<4x512x256x8xf32>) -> tensor<4xf32> {
+  %0 = xten_nn.reduce_mean %arg0 {axes = array<i64: 1, 2, 3>, keepdims = 0 : i64} : (tensor<4x512x256x8xf32>) -> tensor<4xf32>
+  return %0 : tensor<4xf32>
+}
+
+// CHECK-LABEL: func.func @reduce_mean_three_axes(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<4x512x256x8xf32>) -> tensor<4xf32> attributes {torch.onnx_meta.opset_version = 19 : si64} {
+// CHECK: %[[VAL_1:.*]] = torch_c.from_builtin_tensor %[[VAL_0]] : tensor<4x512x256x8xf32> -> !torch.vtensor<[4,512,256,8],f32>
+// CHECK: %[[VAL_2:.*]] = torch.constant.none
+// CHECK: %[[VAL_3:.*]] = torch.constant.bool false
+// CHECK: %[[VAL_4:.*]] = torch.constant.int 1
+// CHECK: %[[VAL_5:.*]] = torch.constant.int 2
+// CHECK: %[[VAL_6:.*]] = torch.constant.int 3
+// CHECK: %[[VAL_7:.*]] = torch.prim.ListConstruct %[[VAL_4]], %[[VAL_5]], %[[VAL_6]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[VAL_8:.*]] = torch.aten.mean.dim %[[VAL_1]], %[[VAL_7]], %[[VAL_3]], %[[VAL_2]] : !torch.vtensor<[4,512,256,8],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[4],f32>
+// CHECK: %[[VAL_9:.*]] = torch_c.to_builtin_tensor %[[VAL_8]] : !torch.vtensor<[4],f32> -> tensor<4xf32>
+// CHECK: return %[[VAL_9]] : tensor<4xf32>
+// CHECK: }
+
+// -----
+
+func.func @reduce_mean_all_axes(%arg0: tensor<4x512x256x8xf32>) -> tensor<f32> {
+  %0 = xten_nn.reduce_mean %arg0 {axes = array<i64: 0, 1, 2, 3>, keepdims = 0 : i64} : (tensor<4x512x256x8xf32>) -> tensor<f32>
+  return %0 : tensor<f32>
+}
+
+// CHECK-LABEL: func.func @reduce_mean_all_axes(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<4x512x256x8xf32>) -> tensor<f32> attributes {torch.onnx_meta.opset_version = 19 : si64} {
+// CHECK: %[[VAL_1:.*]] = torch_c.from_builtin_tensor %[[VAL_0]] : tensor<4x512x256x8xf32> -> !torch.vtensor<[4,512,256,8],f32>
+// CHECK: %[[VAL_2:.*]] = torch.constant.none
+// CHECK: %[[VAL_3:.*]] = torch.constant.bool false
+// CHECK: %[[VAL_4:.*]] = torch.constant.int 0
+// CHECK: %[[VAL_5:.*]] = torch.constant.int 1
+// CHECK: %[[VAL_6:.*]] = torch.constant.int 2
+// CHECK: %[[VAL_7:.*]] = torch.constant.int 3
+// CHECK: %[[VAL_8:.*]] = torch.prim.ListConstruct %[[VAL_4]], %[[VAL_5]], %[[VAL_6]], %[[VAL_7]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[VAL_9:.*]] = torch.aten.mean.dim %[[VAL_1]], %[[VAL_8]], %[[VAL_3]], %[[VAL_2]] : !torch.vtensor<[4,512,256,8],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[],f32>
+// CHECK: %[[VAL_10:.*]] = torch_c.to_builtin_tensor %[[VAL_9]] : !torch.vtensor<[],f32> -> tensor<f32>
+// CHECK: return %[[VAL_10]] : tensor<f32>
+// CHECK: }
+
+// -----
+
+func.func @reduce_mean_noop_with_empty_axes(%arg0: tensor<4x512x256x8xf32>) -> tensor<4x512x256x8xf32> {
+  %0 = xten_nn.reduce_mean %arg0 {axes = array<i64>, keepdims = 1 : i64} : (tensor<4x512x256x8xf32>) -> tensor<4x512x256x8xf32>
+  return %0 : tensor<4x512x256x8xf32>
+}
+
+// CHECK-LABEL: func.func @reduce_mean_noop_with_empty_axes(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<4x512x256x8xf32>) -> tensor<4x512x256x8xf32> attributes {torch.onnx_meta.opset_version = 19 : si64} {
+// CHECK: %[[VAL_1:.*]] = torch_c.from_builtin_tensor %[[VAL_0]] : tensor<4x512x256x8xf32> -> !torch.vtensor<[4,512,256,8],f32>
+// CHECK: %[[VAL_2:.*]] = torch.constant.none
+// CHECK: %[[VAL_3:.*]] = torch.constant.bool true
+// CHECK: %[[VAL_4:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
+// CHECK: %[[VAL_5:.*]] = torch.aten.mean.dim %[[VAL_1]], %[[VAL_4]], %[[VAL_3]], %[[VAL_2]] : !torch.vtensor<[4,512,256,8],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[4,512,256,8],f32>
+// CHECK: %[[VAL_6:.*]] = torch_c.to_builtin_tensor %[[VAL_5]] : !torch.vtensor<[4,512,256,8],f32> -> tensor<4x512x256x8xf32>
+// CHECK: return %[[VAL_6]] : tensor<4x512x256x8xf32>
+// CHECK: }
+
+// -----
+
+func.func @reduce_meanv13_one_axis_keep_dims(%arg0: tensor<4x512x256x8xf32>) -> tensor<4x512x1x8xf32> {
+  %0 = xten_nn.reduce_mean %arg0 {axes = array<i64: 2>, keepdims = 1 : i64} : (tensor<4x512x256x8xf32>) -> tensor<4x512x1x8xf32>
+  return %0 : tensor<4x512x1x8xf32>
+}
+
+// CHECK-LABEL: func.func @reduce_meanv13_one_axis_keep_dims(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<4x512x256x8xf32>) -> tensor<4x512x1x8xf32> attributes {torch.onnx_meta.opset_version = 19 : si64} {
+// CHECK: %[[VAL_1:.*]] = torch_c.from_builtin_tensor %[[VAL_0]] : tensor<4x512x256x8xf32> -> !torch.vtensor<[4,512,256,8],f32>
+// CHECK: %[[VAL_2:.*]] = torch.constant.none
+// CHECK: %[[VAL_3:.*]] = torch.constant.bool true
+// CHECK: %[[VAL_4:.*]] = torch.constant.int 2
+// CHECK: %[[VAL_5:.*]] = torch.prim.ListConstruct %[[VAL_4]] : (!torch.int) -> !torch.list<int>
+// CHECK: %[[VAL_6:.*]] = torch.aten.mean.dim %[[VAL_1]], %[[VAL_5]], %[[VAL_3]], %[[VAL_2]] : !torch.vtensor<[4,512,256,8],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[4,512,1,8],f32>
+// CHECK: %[[VAL_7:.*]] = torch_c.to_builtin_tensor %[[VAL_6]] : !torch.vtensor<[4,512,1,8],f32> -> tensor<4x512x1x8xf32>
+// CHECK: return %[[VAL_7]] : tensor<4x512x1x8xf32>
+// CHECK: }
+
+// -----
+
+func.func @reduce_meanv13_three_axes_keep_dims(%arg0: tensor<4x512x256x8xf32>) -> tensor<4x1x1x1xf32> {
+  %0 = xten_nn.reduce_mean %arg0 {axes = array<i64: 1, 2, 3>, keepdims = 1 : i64} : (tensor<4x512x256x8xf32>) -> tensor<4x1x1x1xf32>
+  return %0 : tensor<4x1x1x1xf32>
+}
+
+// CHECK-LABEL: func.func @reduce_meanv13_three_axes_keep_dims(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<4x512x256x8xf32>) -> tensor<4x1x1x1xf32> attributes {torch.onnx_meta.opset_version = 19 : si64} {
+// CHECK: %[[VAL_1:.*]] = torch_c.from_builtin_tensor %[[VAL_0]] : tensor<4x512x256x8xf32> -> !torch.vtensor<[4,512,256,8],f32>
+// CHECK: %[[VAL_2:.*]] = torch.constant.none
+// CHECK: %[[VAL_3:.*]] = torch.constant.bool true
+// CHECK: %[[VAL_4:.*]] = torch.constant.int 1
+// CHECK: %[[VAL_5:.*]] = torch.constant.int 2
+// CHECK: %[[VAL_6:.*]] = torch.constant.int 3
+// CHECK: %[[VAL_7:.*]] = torch.prim.ListConstruct %[[VAL_4]], %[[VAL_5]], %[[VAL_6]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[VAL_8:.*]] = torch.aten.mean.dim %[[VAL_1]], %[[VAL_7]], %[[VAL_3]], %[[VAL_2]] : !torch.vtensor<[4,512,256,8],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1,1],f32>
+// CHECK: %[[VAL_9:.*]] = torch_c.to_builtin_tensor %[[VAL_8]] : !torch.vtensor<[4,1,1,1],f32> -> tensor<4x1x1x1xf32>
+// CHECK: return %[[VAL_9]] : tensor<4x1x1x1xf32>
+// CHECK: }
+
+// -----
+
+func.func @reduce_meanv13_all_axes_keep_dims(%arg0: tensor<4x512x256x8xf32>) -> tensor<1x1x1x1xf32> {
+  %0 = xten_nn.reduce_mean %arg0 {axes = array<i64: 0, 1, 2, 3>, keepdims = 1 : i64} : (tensor<4x512x256x8xf32>) -> tensor<1x1x1x1xf32>
+  return %0 : tensor<1x1x1x1xf32>
+}
+
+// CHECK-LABEL: func.func @reduce_meanv13_all_axes_keep_dims(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<4x512x256x8xf32>) -> tensor<1x1x1x1xf32> attributes {torch.onnx_meta.opset_version = 19 : si64} {
+// CHECK: %[[VAL_1:.*]] = torch_c.from_builtin_tensor %[[VAL_0]] : tensor<4x512x256x8xf32> -> !torch.vtensor<[4,512,256,8],f32>
+// CHECK: %[[VAL_2:.*]] = torch.constant.none
+// CHECK: %[[VAL_3:.*]] = torch.constant.bool true
+// CHECK: %[[VAL_4:.*]] = torch.constant.int 0
+// CHECK: %[[VAL_5:.*]] = torch.constant.int 1
+// CHECK: %[[VAL_6:.*]] = torch.constant.int 2
+// CHECK: %[[VAL_7:.*]] = torch.constant.int 3
+// CHECK: %[[VAL_8:.*]] = torch.prim.ListConstruct %[[VAL_4]], %[[VAL_5]], %[[VAL_6]], %[[VAL_7]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[VAL_9:.*]] = torch.aten.mean.dim %[[VAL_1]], %[[VAL_8]], %[[VAL_3]], %[[VAL_2]] : !torch.vtensor<[4,512,256,8],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,1,1,1],f32>
+// CHECK: %[[VAL_10:.*]] = torch_c.to_builtin_tensor %[[VAL_9]] : !torch.vtensor<[1,1,1,1],f32> -> tensor<1x1x1x1xf32>
+// CHECK: return %[[VAL_10]] : tensor<1x1x1x1xf32>
+// CHECK: }
+
+// -----
+
+func.func @reduce_meanv13_one_axis(%arg0: tensor<4x512x256x8xf32>) -> tensor<4x512x8xf32> {
+  %0 = xten_nn.reduce_mean %arg0 {axes = array<i64: 2>, keepdims = 0 : i64} : (tensor<4x512x256x8xf32>) -> tensor<4x512x8xf32>
+  return %0 : tensor<4x512x8xf32>
+}
+
+// CHECK-LABEL: func.func @reduce_meanv13_one_axis(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<4x512x256x8xf32>) -> tensor<4x512x8xf32> attributes {torch.onnx_meta.opset_version = 19 : si64} {
+// CHECK: %[[VAL_1:.*]] = torch_c.from_builtin_tensor %[[VAL_0]] : tensor<4x512x256x8xf32> -> !torch.vtensor<[4,512,256,8],f32>
+// CHECK: %[[VAL_2:.*]] = torch.constant.none
+// CHECK: %[[VAL_3:.*]] = torch.constant.bool false
+// CHECK: %[[VAL_4:.*]] = torch.constant.int 2
+// CHECK: %[[VAL_5:.*]] = torch.prim.ListConstruct %[[VAL_4]] : (!torch.int) -> !torch.list<int>
+// CHECK: %[[VAL_6:.*]] = torch.aten.mean.dim %[[VAL_1]], %[[VAL_5]], %[[VAL_3]], %[[VAL_2]] : !torch.vtensor<[4,512,256,8],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[4,512,8],f32>
+// CHECK: %[[VAL_7:.*]] = torch_c.to_builtin_tensor %[[VAL_6]] : !torch.vtensor<[4,512,8],f32> -> tensor<4x512x8xf32>
+// CHECK: return %[[VAL_7]] : tensor<4x512x8xf32>
+// CHECK: }
+
+// -----
+
+func.func @reduce_meanv13_three_axes(%arg0: tensor<4x512x256x8xf32>) -> tensor<4xf32> {
+  %0 = xten_nn.reduce_mean %arg0 {axes = array<i64: 1, 2, 3>, keepdims = 0 : i64} : (tensor<4x512x256x8xf32>) -> tensor<4xf32>
+  return %0 : tensor<4xf32>
+}
+
+// CHECK-LABEL: func.func @reduce_meanv13_three_axes(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<4x512x256x8xf32>) -> tensor<4xf32> attributes {torch.onnx_meta.opset_version = 19 : si64} {
+// CHECK: %[[VAL_1:.*]] = torch_c.from_builtin_tensor %[[VAL_0]] : tensor<4x512x256x8xf32> -> !torch.vtensor<[4,512,256,8],f32>
+// CHECK: %[[VAL_2:.*]] = torch.constant.none
+// CHECK: %[[VAL_3:.*]] = torch.constant.bool false
+// CHECK: %[[VAL_4:.*]] = torch.constant.int 1
+// CHECK: %[[VAL_5:.*]] = torch.constant.int 2
+// CHECK: %[[VAL_6:.*]] = torch.constant.int 3
+// CHECK: %[[VAL_7:.*]] = torch.prim.ListConstruct %[[VAL_4]], %[[VAL_5]], %[[VAL_6]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[VAL_8:.*]] = torch.aten.mean.dim %[[VAL_1]], %[[VAL_7]], %[[VAL_3]], %[[VAL_2]] : !torch.vtensor<[4,512,256,8],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[4],f32>
+// CHECK: %[[VAL_9:.*]] = torch_c.to_builtin_tensor %[[VAL_8]] : !torch.vtensor<[4],f32> -> tensor<4xf32>
+// CHECK: return %[[VAL_9]] : tensor<4xf32>
+// CHECK: }
+
+// -----
+
+func.func @reduce_meanv13_all_axes(%arg0: tensor<4x512x256x8xf32>) -> tensor<f32> {
+  %0 = xten_nn.reduce_mean %arg0 {axes = array<i64: 0, 1, 2, 3>, keepdims = 0 : i64} : (tensor<4x512x256x8xf32>) -> tensor<f32>
+  return %0 : tensor<f32>
+}
+
+// CHECK-LABEL: func.func @reduce_meanv13_all_axes(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<4x512x256x8xf32>) -> tensor<f32> attributes {torch.onnx_meta.opset_version = 19 : si64} {
+// CHECK: %[[VAL_1:.*]] = torch_c.from_builtin_tensor %[[VAL_0]] : tensor<4x512x256x8xf32> -> !torch.vtensor<[4,512,256,8],f32>
+// CHECK: %[[VAL_2:.*]] = torch.constant.none
+// CHECK: %[[VAL_3:.*]] = torch.constant.bool false
+// CHECK: %[[VAL_4:.*]] = torch.constant.int 0
+// CHECK: %[[VAL_5:.*]] = torch.constant.int 1
+// CHECK: %[[VAL_6:.*]] = torch.constant.int 2
+// CHECK: %[[VAL_7:.*]] = torch.constant.int 3
+// CHECK: %[[VAL_8:.*]] = torch.prim.ListConstruct %[[VAL_4]], %[[VAL_5]], %[[VAL_6]], %[[VAL_7]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[VAL_9:.*]] = torch.aten.mean.dim %[[VAL_1]], %[[VAL_8]], %[[VAL_3]], %[[VAL_2]] : !torch.vtensor<[4,512,256,8],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[],f32>
+// CHECK: %[[VAL_10:.*]] = torch_c.to_builtin_tensor %[[VAL_9]] : !torch.vtensor<[],f32> -> tensor<f32>
+// CHECK: return %[[VAL_10]] : tensor<f32>
+// CHECK: }
+
+// -----
+
+func.func @reduce_meanv13_noop_with_empty_axes(%arg0: tensor<4x512x256x8xf32>) -> tensor<1x1x1x1xf32> {
+  %0 = xten_nn.reduce_mean %arg0 {axes = array<i64: 0, 1, 2, 3>, keepdims = 1 : i64} : (tensor<4x512x256x8xf32>) -> tensor<1x1x1x1xf32>
+  return %0 : tensor<1x1x1x1xf32>
+}
+
+// CHECK-LABEL: func.func @reduce_meanv13_noop_with_empty_axes(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<4x512x256x8xf32>) -> tensor<1x1x1x1xf32> attributes {torch.onnx_meta.opset_version = 19 : si64} {
+// CHECK: %[[VAL_1:.*]] = torch_c.from_builtin_tensor %[[VAL_0]] : tensor<4x512x256x8xf32> -> !torch.vtensor<[4,512,256,8],f32>
+// CHECK: %[[VAL_2:.*]] = torch.constant.none
+// CHECK: %[[VAL_3:.*]] = torch.constant.bool true
+// CHECK: %[[VAL_4:.*]] = torch.constant.int 0
+// CHECK: %[[VAL_5:.*]] = torch.constant.int 1
+// CHECK: %[[VAL_6:.*]] = torch.constant.int 2
+// CHECK: %[[VAL_7:.*]] = torch.constant.int 3
+// CHECK: %[[VAL_8:.*]] = torch.prim.ListConstruct %[[VAL_4]], %[[VAL_5]], %[[VAL_6]], %[[VAL_7]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[VAL_9:.*]] = torch.aten.mean.dim %[[VAL_1]], %[[VAL_8]], %[[VAL_3]], %[[VAL_2]] : !torch.vtensor<[4,512,256,8],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,1,1,1],f32>
+// CHECK: %[[VAL_10:.*]] = torch_c.to_builtin_tensor %[[VAL_9]] : !torch.vtensor<[1,1,1,1],f32> -> tensor<1x1x1x1xf32>
+// CHECK: return %[[VAL_10]] : tensor<1x1x1x1xf32>
+// CHECK: }
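
Note on the new shape inference: ReduceMeanOp::inferReturnTypeComponents also accepts negative axes and normalizes them into [0, rank), following the ONNX spec range [-r, r-1]. None of the tests above exercises that path; a minimal illustrative case (hypothetical, not part of the committed test file, using the same op syntax as the tests):

  %0 = xten_nn.reduce_mean %arg0 {axes = array<i64: -2>, keepdims = 1 : i64} : (tensor<4x512x256x8xf32>) -> tensor<4x512x1x8xf32>

Axis -2 normalizes to 2, so the inferred result type collapses dimension 2 (size 256) to 1, matching the explicit axes = array<i64: 2> case above.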