diff --git a/lib/Conversion/TorchToMhlo/BasicOp.cpp b/lib/Conversion/TorchToMhlo/BasicOp.cpp
index ecec4882c5523..907eec3814475 100644
--- a/lib/Conversion/TorchToMhlo/BasicOp.cpp
+++ b/lib/Conversion/TorchToMhlo/BasicOp.cpp
@@ -22,7 +22,6 @@
 #include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"
 #include "torch-mlir/Dialect/Torch/Utils/Utils.h"
 #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
-#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
 #include <iostream>
 #include <numeric>
 
@@ -618,9 +617,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
 namespace {
 template <>
 LogicalResult ConvertAtenOp<AtenGeluOp>::matchAndRewrite(
-    AtenGeluOp op,
-    OpAdaptor adaptor,
-    ConversionPatternRewriter& rewriter) const {
+    AtenGeluOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
   Location loc = op.getLoc();
   Value input = adaptor.self();
   auto inputTy = input.getType().template dyn_cast<RankedTensorType>();
@@ -641,7 +639,6 @@ LogicalResult ConvertAtenOp<AtenGeluOp>::matchAndRewrite(
 }
 } // namespace
 
-
 // AtenErfOp
 namespace {
 template <>
diff --git a/lib/Conversion/TorchToMhlo/CMakeLists.txt b/lib/Conversion/TorchToMhlo/CMakeLists.txt
index 3c036e7ef8d8f..c1fdef423ba7a 100644
--- a/lib/Conversion/TorchToMhlo/CMakeLists.txt
+++ b/lib/Conversion/TorchToMhlo/CMakeLists.txt
@@ -5,6 +5,7 @@ add_mlir_conversion_library(TorchMLIRTorchToMhlo
   GatherOp.cpp
   ViewLikeOps.cpp
   ReductionOp.cpp
+  PoolingOp.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/TorchToMhlo
diff --git a/lib/Conversion/TorchToMhlo/PoolingOp.cpp b/lib/Conversion/TorchToMhlo/PoolingOp.cpp
new file mode 100644
index 0000000000000..ab28e98aeca8f
--- /dev/null
+++ b/lib/Conversion/TorchToMhlo/PoolingOp.cpp
@@ -0,0 +1,557 @@
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+// Also available under a BSD-style license. See LICENSE.
+//
+//===----------------------------------------------------------------------===//
+
+#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"
+
+#include "../PassDetail.h"
+#include "./MhloLegalizeUtils.h"
+#include "./PopulatePatterns.h"
+#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
+#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "torch-mlir/Conversion/Utils/Utils.h"
+#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
+#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
+#include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"
+#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
+#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
+#include <iostream>
+#include <numeric>
+
+using namespace mlir;
+using namespace mlir::torch;
+using namespace mlir::torch::Torch;
+
+static Value createInitialValueForAtenPoolingOp(Operation *op, Type elementTy,
+                                                PatternRewriter &rewriter) {
+  auto constType = RankedTensorType::get({}, elementTy);
+  // Avg pooling
+  if (isa<AtenAvgPool2dOp>(op)) {
+    if (elementTy.isa<mlir::FloatType>()) {
+      auto constAttr = DenseElementsAttr::get(
+          constType, {APFloat::getZero(
+                         elementTy.cast<mlir::FloatType>().getFloatSemantics(),
+                         /*negative=*/false)});
+      return rewriter.create<mhlo::ConstantOp>(op->getLoc(), constType,
+                                               constAttr);
+    } else if (elementTy.isa<mlir::IntegerType>() &&
+               elementTy.getIntOrFloatBitWidth() != 8) {
+      auto constAttr = DenseElementsAttr::get(
+          constType, {APInt::getZero(elementTy.getIntOrFloatBitWidth())});
+      return rewriter.create<mhlo::ConstantOp>(op->getLoc(), constType,
+                                               constAttr);
+    }
+  }
+
+  // Max pooling
+  if (isa<AtenMaxPool2dOp, AtenMaxPool2dWithIndicesOp>(op)) {
+    if (elementTy.isa<mlir::FloatType>()) {
+      auto constAttr = DenseElementsAttr::get(
+          constType, {APFloat::getLargest(
+                         elementTy.cast<mlir::FloatType>().getFloatSemantics(),
+                         /*negative=*/true)});
+      return rewriter.create<mhlo::ConstantOp>(op->getLoc(), constType,
+                                               constAttr);
+    } else if (elementTy.isa<mlir::IntegerType>() &&
+               elementTy.getIntOrFloatBitWidth() != 8) {
+      auto constAttr = DenseElementsAttr::get(
+          constType,
+          {APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth())});
+      return rewriter.create<mhlo::ConstantOp>(op->getLoc(), constType,
+                                               constAttr);
+    }
+  }
+  op->emitError("unimplemented lowering in AtenPoolingOp");
+  return nullptr;
+}
+
+namespace {
+template <typename AtenOpT>
+class ConvertAtenPoolingOp : public OpConversionPattern<AtenOpT> {
+public:
+  using OpConversionPattern<AtenOpT>::OpConversionPattern;
+  using OpAdaptor = typename AtenOpT::Adaptor;
+  LogicalResult
+  matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+} // namespace
+
+// AtenMaxPool2dOp
+namespace {
+template <>
+LogicalResult ConvertAtenPoolingOp<AtenMaxPool2dOp>::matchAndRewrite(
+    AtenMaxPool2dOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  Value input = adaptor.self();
+  auto inputTy = input.getType().cast<RankedTensorType>();
+  auto inputElemTy = inputTy.getElementType();
+
+  auto inputRank = inputTy.getRank();
+  auto outTy =
+      getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();
+
+  if (inputRank <= 2) {
+    return op.emitError(
+        "max_pooling2d only supports inputs with rank higher than 2");
+  }
+  SmallVector<int64_t, 2> padding, kernelSize, stride, dilation;
+  bool ceilMode = false;
+
+  if (!(matchPattern(op.kernel_size(), m_TorchConstantIntList(kernelSize)))) {
+    return rewriter.notifyMatchFailure(
+        op, "non-const int kernel size unsupported!");
+  }
+  if (!(matchPattern(op.stride(), m_TorchConstantIntList(stride)))) {
+    return rewriter.notifyMatchFailure(op, "non-const int stride unsupported!");
+  }
+  if (!(matchPattern(op.padding(), m_TorchConstantIntList(padding)))) {
+    return rewriter.notifyMatchFailure(op,
+                                       "non-const int padding unsupported!");
+  }
+  if (!(matchPattern(op.dilation(), m_TorchConstantIntList(dilation)))) {
+    return rewriter.notifyMatchFailure(op,
+                                       "non-const int dilation unsupported!");
+  }
+  if (!(matchPattern(op.ceil_mode(), m_TorchConstantBool(&ceilMode)))) {
+    return rewriter.notifyMatchFailure(op,
+                                       "non-const bool ceil_mode unsupported!");
+  }
+
+  // prepend 1 to kernelSize, stride, dilation until they are of same rank as
+  // input
+  SmallVector<int64_t> mhloStride(inputRank, 1);
+  SmallVector<int64_t> mhloDilation(inputRank, 1);
+  SmallVector<int64_t> mhloKernelSize(inputRank, 1);
+  SmallVector<int64_t> mhloPadding(inputRank * 2, 0);
+  std::copy(dilation.begin(), dilation.end(),
+            mhloDilation.begin() + inputRank - 2);
+  std::copy(stride.begin(), stride.end(), mhloStride.begin() + inputRank - 2);
+  std::copy(kernelSize.begin(), kernelSize.end(),
+            mhloKernelSize.begin() + inputRank - 2);
+
+  Value initVal = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter);
+
+  mhloPadding[mhloPadding.size() - 4] = padding[0];
+  mhloPadding[mhloPadding.size() - 3] = padding[0];
+  mhloPadding[mhloPadding.size() - 2] = padding[1];
+  mhloPadding[mhloPadding.size() - 1] = padding[1];
+
+  DenseIntElementsAttr windowDimensions = DenseIntElementsAttr::get(
+      RankedTensorType::get({static_cast<int64_t>(mhloKernelSize.size())},
+                            rewriter.getI64Type()),
+      mhloKernelSize);
+  DenseIntElementsAttr windowStrides = DenseIntElementsAttr::get(
+      RankedTensorType::get({static_cast<int64_t>(mhloStride.size())},
+                            rewriter.getI64Type()),
+      mhloStride);
+  DenseIntElementsAttr baseDilations;
+  DenseIntElementsAttr windowDilations = DenseIntElementsAttr::get(
+      RankedTensorType::get({static_cast<int64_t>(mhloDilation.size())},
+                            rewriter.getI64Type()),
+      mhloDilation);
+  DenseIntElementsAttr pad = DenseIntElementsAttr::get(
+      RankedTensorType::get(
+          {static_cast<int64_t>(inputRank), static_cast<int64_t>(2)},
+          rewriter.getI64Type()),
+      mhloPadding);
+  auto reduceWindowOp = rewriter.create<mhlo::ReduceWindowOp>(
+      op->getLoc(), outTy, input, initVal, windowDimensions, windowStrides,
+      baseDilations, windowDilations, pad);
+
+  Block &block = reduceWindowOp.body().emplaceBlock();
+
+  auto blockArgumentTy = RankedTensorType::get({}, inputElemTy);
+  block.addArgument(blockArgumentTy, op->getLoc());
+  block.addArgument(blockArgumentTy, op->getLoc());
+
+  auto *firstArg = block.args_begin();
+  auto secondArg = block.args_rbegin();
+
+  {
+    OpBuilder::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPointToStart(&block);
+    Value result =
+        rewriter.create<mhlo::MaxOp>(op->getLoc(), *firstArg, *secondArg);
+    rewriter.create<mhlo::ReturnOp>(op->getLoc(), result);
+  }
+
+  rewriter.replaceOp(op, reduceWindowOp.getResults());
+  return success();
+}
+} // namespace
+
+// AtenMaxPool2dWithIndicesOp
+namespace {
+template <>
+LogicalResult ConvertAtenPoolingOp<AtenMaxPool2dWithIndicesOp>::matchAndRewrite(
+    AtenMaxPool2dWithIndicesOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  Value input = adaptor.self();
+  auto inputTy = input.getType().cast<RankedTensorType>();
+  auto inputElemTy = inputTy.getElementType();
+  auto inputShape = inputTy.getShape();
+  auto inputRank = inputTy.getRank();
+  auto outValTy =
+      getTypeConverter()->convertType(op.getType(0)).cast<RankedTensorType>();
+  auto outIdxTy =
+      getTypeConverter()->convertType(op.getType(1)).cast<RankedTensorType>();
+
+  if (inputRank <= 2) {
+    return op.emitError(
+        "max_pooling2d only supports inputs with rank higher than 2");
+  }
+  SmallVector<int64_t, 2> padding, kernelSize, stride, dilation;
+  bool ceilMode = false;
+
+  if (!(matchPattern(op.kernel_size(), m_TorchConstantIntList(kernelSize)))) {
+    return rewriter.notifyMatchFailure(
+        op, "non-const int kernel size unsupported!");
+  }
+  if (!(matchPattern(op.stride(), m_TorchConstantIntList(stride)))) {
+    return rewriter.notifyMatchFailure(op, "non-const int stride unsupported!");
+  }
+  if (!(matchPattern(op.padding(), m_TorchConstantIntList(padding)))) {
+    return rewriter.notifyMatchFailure(op,
+                                       "non-const int padding unsupported!");
+  }
+  if (!(matchPattern(op.dilation(), m_TorchConstantIntList(dilation)))) {
+    return rewriter.notifyMatchFailure(op,
+                                       "non-const int dilation unsupported!");
+  }
+  if (!(matchPattern(op.ceil_mode(), m_TorchConstantBool(&ceilMode)))) {
+    return rewriter.notifyMatchFailure(op,
+                                       "non-const bool ceil_mode unsupported!");
+  }
+
+  // prepend 1 to kernelSize, stride, dilation until they are of same rank as
+  // input
+  SmallVector<int64_t> mhloStride(inputRank, 1);
+  SmallVector<int64_t> mhloDilation(inputRank, 1);
+  SmallVector<int64_t> mhloKernelSize(inputRank, 1);
+  SmallVector<int64_t> mhloPadding(inputRank * 2, 0);
+  std::copy(dilation.begin(), dilation.end(),
+            mhloDilation.begin() + inputRank - 2);
+  std::copy(stride.begin(), stride.end(), mhloStride.begin() + inputRank - 2);
+  std::copy(kernelSize.begin(), kernelSize.end(),
+            mhloKernelSize.begin() + inputRank - 2);
+
+  Value initVal = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter);
+
+  mhloPadding[mhloPadding.size() - 4] = padding[0];
+  mhloPadding[mhloPadding.size() - 3] = padding[0];
+  mhloPadding[mhloPadding.size() - 2] = padding[1];
+  mhloPadding[mhloPadding.size() - 1] = padding[1];
+
+  DenseIntElementsAttr windowDimensions = DenseIntElementsAttr::get(
+      RankedTensorType::get({static_cast<int64_t>(mhloKernelSize.size())},
+                            rewriter.getI64Type()),
+      mhloKernelSize);
+  DenseIntElementsAttr windowStrides = DenseIntElementsAttr::get(
+      RankedTensorType::get({static_cast<int64_t>(mhloStride.size())},
+                            rewriter.getI64Type()),
+      mhloStride);
+  DenseIntElementsAttr baseDilations;
+  DenseIntElementsAttr windowDilations = DenseIntElementsAttr::get(
+      RankedTensorType::get({static_cast<int64_t>(mhloDilation.size())},
+                            rewriter.getI64Type()),
+      mhloDilation);
+  DenseIntElementsAttr pad = DenseIntElementsAttr::get(
+      RankedTensorType::get(
+          {static_cast<int64_t>(inputRank), static_cast<int64_t>(2)},
+          rewriter.getI64Type()),
+      mhloPadding);
+
+  auto inputShapeInfo = mhlo::getDimSizesOfTensor(rewriter, op, input);
+  if (failed(inputShapeInfo)) {
+    return rewriter.notifyMatchFailure(
+        op, "failed to get dimension sizes of the input");
+  }
+  auto inputShapeVec = *inputShapeInfo;
+  auto inputShapeTensor = rewriter.create<tensor::FromElementsOp>(
+      op->getLoc(), inputShapeVec);
+
+  SmallVector<Value> initIndexShapeVec;
+  for (int64_t i = 0; i < inputRank - 2; i++)
+    initIndexShapeVec.push_back(inputShapeVec[i]);
+  initIndexShapeVec.push_back(rewriter.create<arith::MulIOp>(
+      op->getLoc(), inputShapeVec[inputRank - 1],
+      inputShapeVec[inputRank - 2]));
+  auto initIndexShapeTensor = rewriter.create<tensor::FromElementsOp>(
+      op->getLoc(), initIndexShapeVec);
+
+  SmallVector<int64_t> initIndexShapeForType(inputShape.begin(),
+                                             inputShape.end() - 2);
+  if (inputShape[inputRank - 1] == ShapedType::kDynamicSize ||
+      inputShape[inputRank - 2] == ShapedType::kDynamicSize) {
+    initIndexShapeForType.push_back(ShapedType::kDynamicSize);
+  } else {
+    initIndexShapeForType.push_back(inputShape[inputRank - 1] *
+                                    inputShape[inputRank - 2]);
+  }
+
+  auto initIndexTensor =
+      rewriter
+          .create<mhlo::DynamicIotaOp>(
+              op->getLoc(),
+              RankedTensorType::get(initIndexShapeForType,
+                                    rewriter.getI64Type()),
+              initIndexShapeTensor, static_cast<uint64_t>(inputRank - 2))
+          .getResult();
+
+  auto indexTensor =
+      rewriter
+          .create<mhlo::DynamicReshapeOp>(
+              op->getLoc(),
+              RankedTensorType::get(inputShape, rewriter.getI64Type()),
+              initIndexTensor, inputShapeTensor)
+          .getResult();
+
+  Value initIdx =
+      mhlo::getConstTensor<int64_t>(rewriter, op, {0}, {}).getValue();
+
+  auto reduceWindowOp = rewriter.create<mhlo::ReduceWindowOp>(
+      op->getLoc(), mlir::TypeRange{outValTy, outIdxTy},
+      mlir::ValueRange{input, indexTensor}, mlir::ValueRange{initVal, initIdx},
+      windowDimensions, windowStrides, baseDilations, windowDilations, pad);
+
+  Block &block = reduceWindowOp.body().emplaceBlock();
+
+  // Add bb argument
+  auto blockValArgumentType = RankedTensorType::get({}, inputElemTy);
+  auto blockIdxArgumentType = RankedTensorType::get({}, rewriter.getI64Type());
+  auto compareResultType = RankedTensorType::get({}, rewriter.getI1Type());
+  block.addArgument(blockValArgumentType, op->getLoc());
+  block.addArgument(blockIdxArgumentType, op->getLoc());
+  block.addArgument(blockValArgumentType, op->getLoc());
+  block.addArgument(blockIdxArgumentType, op->getLoc());
+  auto *firstValArg = block.args_begin();
+  auto *firstIdxArg = std::next(firstValArg);
+  auto *secondValArg = std::next(firstIdxArg);
+  auto *secondIdxArg = std::next(secondValArg);
+
+  mhlo::ComparisonTypeAttr compareTypeAttr;
+  if (inputTy.getElementType().isa<mlir::FloatType>()) {
+    compareTypeAttr = mhlo::ComparisonTypeAttr::get(
+        rewriter.getContext(), mhlo::ComparisonType::FLOAT);
+  } else if (inputTy.getElementType().isa<mlir::IntegerType>()) {
+    compareTypeAttr = mhlo::ComparisonTypeAttr::get(
+        rewriter.getContext(), mhlo::ComparisonType::SIGNED);
+  }
+  mhlo::ComparisonDirectionAttr compareGeDirectionAttr =
+      mhlo::ComparisonDirectionAttr::get(rewriter.getContext(),
+                                         mhlo::ComparisonDirection::GE);
+  mhlo::ComparisonDirectionAttr compareEqDirectionAttr =
+      mhlo::ComparisonDirectionAttr::get(rewriter.getContext(),
+                                         mhlo::ComparisonDirection::EQ);
+
+  {
+    OpBuilder::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPointToStart(&block);
+
+    Value compareGeResult = rewriter.create<mhlo::CompareOp>(
+        op->getLoc(), compareResultType, *firstValArg, *secondValArg,
+        compareGeDirectionAttr, compareTypeAttr);
+    Value retValResult = rewriter.create<mhlo::SelectOp>(
+        op->getLoc(), compareGeResult, *firstValArg, *secondValArg);
+
+    // Get smaller index if compared values are equal.
+    Value compareEqResult = rewriter.create<mhlo::CompareOp>(
+        op->getLoc(), compareResultType, *firstValArg, *secondValArg,
+        compareEqDirectionAttr, compareTypeAttr);
+    Value minIdx =
+        rewriter.create<mhlo::MinOp>(op->getLoc(), *firstIdxArg, *secondIdxArg);
+    Value idxWithGeVal = rewriter.create<mhlo::SelectOp>(
+        op->getLoc(), compareGeResult, *firstIdxArg, *secondIdxArg);
+    Value retIdxResult = rewriter.create<mhlo::SelectOp>(
+        op->getLoc(), compareEqResult, minIdx, idxWithGeVal);
+
+    rewriter.create<mhlo::ReturnOp>(
+        op->getLoc(), mlir::ValueRange{retValResult, retIdxResult});
+  }
+
+  rewriter.replaceOp(op, reduceWindowOp.getResults());
+  return success();
+}
+} // namespace
+
+// AtenAvgPool2dOp
+namespace {
+template <>
+LogicalResult ConvertAtenPoolingOp<AtenAvgPool2dOp>::matchAndRewrite(
+    AtenAvgPool2dOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  Value input = adaptor.self();
+  auto inputTy = input.getType().cast<RankedTensorType>();
+  auto inputElemTy = inputTy.getElementType();
+  auto inputRank = inputTy.getRank();
+  auto outTy =
+      getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();
+  auto outShape = outTy.getShape();
+
+  if (inputRank <= 2) {
+    return op.emitError(
+        "avg_pooling2d only supports inputs with rank higher than 2");
+  }
+  SmallVector<int64_t, 2> padding, kernelSize, stride;
+  bool ceilMode = false;
+  bool countIncludePad = true;
+
+  if (!(matchPattern(op.kernel_size(), m_TorchConstantIntList(kernelSize)))) {
+    return rewriter.notifyMatchFailure(
+        op, "non-const int kernel size unsupported!");
+  }
+  if (!(matchPattern(op.stride(), m_TorchConstantIntList(stride)))) {
+    return rewriter.notifyMatchFailure(op, "non-const int stride unsupported!");
+  }
+  if (!(matchPattern(op.padding(), m_TorchConstantIntList(padding)))) {
+    return rewriter.notifyMatchFailure(op,
+                                       "non-const int padding unsupported!");
+  }
+  if (!(matchPattern(op.ceil_mode(), m_TorchConstantBool(&ceilMode)))) {
+    return rewriter.notifyMatchFailure(op,
+                                       "non-const bool ceil_mode unsupported!");
+  }
+  if (!(matchPattern(op.count_include_pad(),
+                     m_TorchConstantBool(&countIncludePad)))) {
+    return rewriter.notifyMatchFailure(
+        op, "non-const bool count_include_pad unsupported!");
+  }
+  if (succeeded(checkNotNone(rewriter, op, op.divisor_override()))) {
+    return rewriter.notifyMatchFailure(
+        op, "only None divisor_override supported for now!");
+  }
+
+  // prepend 1 to kernelSize, stride, dilation until they are of same rank as
+  // input
+  SmallVector<int64_t> mhloStride(inputRank, 1);
+  SmallVector<int64_t> mhloDilation(inputRank, 1);
+  SmallVector<int64_t> mhloKernelSize(inputRank, 1);
+  SmallVector<int64_t> mhloPadding(inputRank * 2, 0);
+
+  std::copy(stride.begin(), stride.end(), mhloStride.begin() + inputRank - 2);
+  std::copy(kernelSize.begin(), kernelSize.end(),
+            mhloKernelSize.begin() + inputRank - 2);
+  mhloPadding[mhloPadding.size() - 4] = padding[0];
+  mhloPadding[mhloPadding.size() - 3] = padding[0];
+  mhloPadding[mhloPadding.size() - 2] = padding[1];
+  mhloPadding[mhloPadding.size() - 1] = padding[1];
+
+  Value initVal = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter);
+
+  DenseIntElementsAttr windowDimensions = DenseIntElementsAttr::get(
+      RankedTensorType::get({static_cast<int64_t>(mhloKernelSize.size())},
+                            rewriter.getI64Type()),
+      mhloKernelSize);
+  DenseIntElementsAttr windowStrides = DenseIntElementsAttr::get(
+      RankedTensorType::get({static_cast<int64_t>(mhloStride.size())},
+                            rewriter.getI64Type()),
+      mhloStride);
+  DenseIntElementsAttr baseDilations;
+  DenseIntElementsAttr windowDilations = DenseIntElementsAttr::get(
+      RankedTensorType::get({static_cast<int64_t>(mhloDilation.size())},
+                            rewriter.getI64Type()),
+      mhloDilation);
+  DenseIntElementsAttr pad = DenseIntElementsAttr::get(
+      RankedTensorType::get(
+          {static_cast<int64_t>(inputRank), static_cast<int64_t>(2)},
+          rewriter.getI64Type()),
+      mhloPadding);
+
+  auto reduceWindowSum = rewriter.create<mhlo::ReduceWindowOp>(
+      op->getLoc(), outTy, input, initVal, windowDimensions, windowStrides,
+      baseDilations, windowDilations, pad);
+
+  Block &sumBlock = reduceWindowSum.body().emplaceBlock();
+
+  // Add bb argument
+  auto blockArgumentType = RankedTensorType::get({}, inputElemTy);
+  sumBlock.addArgument(blockArgumentType, op->getLoc());
+  sumBlock.addArgument(blockArgumentType, op->getLoc());
+  auto *firstArg = sumBlock.args_begin();
+  auto secondArg = sumBlock.args_rbegin();
+
+  {
+    OpBuilder::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPointToStart(&sumBlock);
+
+    Value sumResult =
+        rewriter.create<mhlo::AddOp>(op->getLoc(), *firstArg, *secondArg);
+    rewriter.create<mhlo::ReturnOp>(op->getLoc(), sumResult);
+  }
+
+  // Use kernel size as the divisor
+  if (countIncludePad) {
+    Value divisor = mhlo::getConstTensor<int64_t>(
+                        rewriter, op, {kernelSize[0] * kernelSize[1]}, {})
+                        .getValue();
+    divisor = mhlo::promoteType(rewriter, divisor, outTy);
+    DenseIntElementsAttr bcastDimensions;
+    rewriter.replaceOpWithNewOp<chlo::BroadcastDivOp>(
+        op, outTy, reduceWindowSum.getResult(0), divisor, bcastDimensions);
+    return success();
+  }
+
+  // Use another mhlo.ReduceWindowOp to get the divisor
+  Value windowSizeConst =
+      mhlo::getConstTensor<float>(rewriter, op, {1.0}, {}).getValue();
+  windowSizeConst = mhlo::promoteType(rewriter, windowSizeConst, outTy);
+  auto inputShapeVec = *mhlo::getDimSizesOfTensor(rewriter, op, input);
+  auto inputShapeTensor = rewriter.create<tensor::FromElementsOp>(
+      op->getLoc(), inputShapeVec);
+
+  windowSizeConst = rewriter.create<mhlo::DynamicBroadcastInDimOp>(
+      op->getLoc(),
+      RankedTensorType::get(inputTy.getShape(), outTy.getElementType()),
+      windowSizeConst, inputShapeTensor, rewriter.getI64TensorAttr({}));
+
+  Value zero = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter);
+  auto reduceWindowSize = rewriter.create<mhlo::ReduceWindowOp>(
+      op->getLoc(), RankedTensorType::get(outShape, inputElemTy),
+      windowSizeConst, zero, windowDimensions, windowStrides, baseDilations,
+      windowDilations, pad);
+
+  Block &sizeBlock = reduceWindowSize.body().emplaceBlock();
+
+  // Add bb argument
+  blockArgumentType = RankedTensorType::get({}, inputElemTy);
+  sizeBlock.addArgument(blockArgumentType, op->getLoc());
+  sizeBlock.addArgument(blockArgumentType, op->getLoc());
+  firstArg = sizeBlock.args_begin();
+  secondArg = sizeBlock.args_rbegin();
+
+  {
+    OpBuilder::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPointToStart(&sizeBlock);
+
+    Value sumResult =
+        rewriter.create<mhlo::AddOp>(op->getLoc(), *firstArg, *secondArg);
+    rewriter.create<mhlo::ReturnOp>(op->getLoc(), sumResult);
+  }
+
+  rewriter.replaceOpWithNewOp<mhlo::DivOp>(
+      op, outTy, reduceWindowSum.getResult(0), reduceWindowSize.getResult(0));
+  return success();
+}
+
+} // namespace
+
+void mlir::torch::torch_to_mhlo::populatePoolingOpPatternsAndLegality(
+    TypeConverter &typeConverter, RewritePatternSet &patterns,
+    ConversionTarget &target) {
+  MLIRContext *context = patterns.getContext();
+  target.addIllegalOp<AtenMaxPool2dOp>();
+  patterns.add<ConvertAtenPoolingOp<AtenMaxPool2dOp>>(typeConverter, context);
+  target.addIllegalOp<AtenAvgPool2dOp>();
+  patterns.add<ConvertAtenPoolingOp<AtenAvgPool2dOp>>(typeConverter, context);
+  target.addIllegalOp<AtenMaxPool2dWithIndicesOp>();
+  patterns.add<ConvertAtenPoolingOp<AtenMaxPool2dWithIndicesOp>>(typeConverter,
+                                                                 context);
+}
diff --git a/lib/Conversion/TorchToMhlo/PopulatePatterns.h b/lib/Conversion/TorchToMhlo/PopulatePatterns.h
index c84cec6380c69..e69b4d47f3c88 100644
--- a/lib/Conversion/TorchToMhlo/PopulatePatterns.h
+++ b/lib/Conversion/TorchToMhlo/PopulatePatterns.h
@@ -28,6 +28,11 @@ void populateGatherOpPatternsAndLegality(TypeConverter &typeConverter,
 void populateReductionOpPatternsAndLegality(TypeConverter &typeConverter,
                                             RewritePatternSet &patterns,
                                             ConversionTarget &target);
+
+void populatePoolingOpPatternsAndLegality(TypeConverter &typeConverter,
+                                          RewritePatternSet &patterns,
+                                          ConversionTarget &target);
+
 } // namespace torch_to_mhlo
 } // namespace torch
 } // namespace mlir
diff --git a/lib/Conversion/TorchToMhlo/TorchToMhlo.cpp b/lib/Conversion/TorchToMhlo/TorchToMhlo.cpp
index 5007a8a26d33f..a724b330dce47 100644
--- a/lib/Conversion/TorchToMhlo/TorchToMhlo.cpp
+++ b/lib/Conversion/TorchToMhlo/TorchToMhlo.cpp
@@ -11,8 +11,8 @@
 
 #include "../PassDetail.h"
 #include "./PopulatePatterns.h"
-#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
 #include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
+#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Traits.h"
@@ -23,7 +23,6 @@
 #include "torch-mlir/Dialect/Torch/Utils/Utils.h"
 #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h"
 #include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
-#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
 
 using namespace mlir;
 using namespace mlir::torch;
@@ -43,8 +42,9 @@ class ConvertTorchToMhlo : public ConvertTorchToMhloBase<ConvertTorchToMhlo> {
   void runOnOperation() override {
     MLIRContext *context = &getContext();
     ConversionTarget target(*context);
-    target.addLegalDialect();
+    target.addLegalDialect();
 
     TypeConverter typeConverter;
     typeConverter.addConversion([](Type type) { return type; });
@@ -54,12 +54,14 @@ class ConvertTorchToMhlo : public ConvertTorchToMhloBase<ConvertTorchToMhlo> {
 
     torch_to_mhlo::populateBasicOpPatternsAndLegality(typeConverter, patterns,
                                                       target);
-    torch_to_mhlo::populateViewLikeOpPatternsAndLegality(typeConverter, patterns,
-                                                         target);
+    torch_to_mhlo::populateViewLikeOpPatternsAndLegality(typeConverter,
+                                                         patterns, target);
     torch_to_mhlo::populateGatherOpPatternsAndLegality(typeConverter, patterns,
                                                        target);
     torch_to_mhlo::populateReductionOpPatternsAndLegality(typeConverter,
                                                           patterns, target);
+    torch_to_mhlo::populatePoolingOpPatternsAndLegality(typeConverter, patterns,
+                                                        target);
 
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns)))) {
diff --git a/test/Conversion/TorchToMhlo/pooling.mlir b/test/Conversion/TorchToMhlo/pooling.mlir
new file mode 100644
index 0000000000000..ab057522ba2fd
--- /dev/null
+++ b/test/Conversion/TorchToMhlo/pooling.mlir
@@ -0,0 +1,218 @@
+// RUN: torch-mlir-opt <%s -convert-torch-to-mhlo -split-input-file -verify-diagnostics | FileCheck %s
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.max_pool2d(
+// CHECK-SAME:      %[[VAL_0:.*]]: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
+// CHECK:         %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?,?,?],f32> -> tensor<?x?x?x?xf32>
+// CHECK:         %int2 = torch.constant.int 2
+// CHECK:         %int1 = torch.constant.int 1
+// CHECK:         %int0 = torch.constant.int 0
+// CHECK:         %false = torch.constant.bool false
+// CHECK:         %[[VAL_2:.*]] = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK:         %[[VAL_3:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK:         %[[VAL_4:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK:         %[[VAL_5:.*]] = torch.prim.ListConstruct %int2, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK:         %[[VAL_6:.*]] = mhlo.constant dense<-3.40282347E+38> : tensor<f32>
+// CHECK:         %[[VAL_7:.*]] = "mhlo.reduce_window"(%[[VAL_1]], %[[VAL_6]]) ({
+// CHECK:         ^bb0(%[[VAL_8:.*]]: tensor<f32>, %[[VAL_9:.*]]: tensor<f32>):
+// CHECK:           %[[VAL_10:.*]] = mhlo.maximum %[[VAL_8]], %[[VAL_9]] : tensor<f32>
+// CHECK:           "mhlo.return"(%[[VAL_10]]) : (tensor<f32>) -> ()
+// CHECK:         }) {padding = dense<0> : tensor<4x2xi64>, window_dilations = dense<[1, 1, 2, 1]> : tensor<4xi64>, window_dimensions = dense<[1, 1, 2, 2]> : tensor<4xi64>, window_strides = dense<1> : tensor<4xi64>} : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32>
+// CHECK:         %[[VAL_11:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<?x?x?x?xf32> -> !torch.vtensor<[?,?,?,?],f32>
+// CHECK:         return %[[VAL_11]] : !torch.vtensor<[?,?,?,?],f32>
+func.func @torch.aten.max_pool2d(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
+  %int2 = torch.constant.int 2
+  %int1 = torch.constant.int 1
+  %int0 = torch.constant.int 0
+  %false = torch.constant.bool false
+  %0 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
+  %1 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
+  %2 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
+  %3 = torch.prim.ListConstruct %int2, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
+  %4 = torch.aten.max_pool2d %arg0, %0, %1, %2, %3, %false : !torch.vtensor<[?,?,?,?],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool -> !torch.vtensor<[?,?,?,?],f32>
+  return %4 : !torch.vtensor<[?,?,?,?],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.max_pool2d$padding(
+// CHECK-SAME:      %[[VAL_0:.*]]: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
+// CHECK:         %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?,?,?],f32> -> tensor<?x?x?x?xf32>
+// CHECK:         %int2 = torch.constant.int 2
+// CHECK:         %int1 = torch.constant.int 1
+// CHECK:         %false = torch.constant.bool false
+// CHECK:         %[[VAL_2:.*]] = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK:         %[[VAL_3:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK:         %[[VAL_4:.*]] = torch.prim.ListConstruct %int2, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK:         %[[VAL_5:.*]] = mhlo.constant dense<-3.40282347E+38> : tensor<f32>
+// CHECK:         %[[VAL_6:.*]] = "mhlo.reduce_window"(%[[VAL_1]], %[[VAL_5]]) ({
+// CHECK:         ^bb0(%[[VAL_8:.*]]: tensor<f32>, %[[VAL_9:.*]]: tensor<f32>):
+// CHECK:           %[[VAL_10:.*]] = mhlo.maximum %[[VAL_8]], %[[VAL_9]] : tensor<f32>
+// CHECK:           "mhlo.return"(%[[VAL_10]]) : (tensor<f32>) -> ()
+// CHECK:         })
+// CHECK-SAME{LITERAL}: {padding = dense<[[0, 0], [0, 0], [2, 2], [1, 1]]> : tensor<4x2xi64>, window_dilations = dense<[1, 1, 2, 1]> : tensor<4xi64>, window_dimensions = dense<[1, 1, 2, 2]> : tensor<4xi64>, window_strides = dense<1> : tensor<4xi64>} : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32>
+// CHECK:         %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<?x?x?x?xf32> -> !torch.vtensor<[?,?,?,?],f32>
+// CHECK:         return %[[VAL_7]] : !torch.vtensor<[?,?,?,?],f32>
+func.func @torch.aten.max_pool2d$padding(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
+  %int2 = torch.constant.int 2
+  %int1 = torch.constant.int 1
+  %false = torch.constant.bool false
+  %0 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
+  %1 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
+  %2 = torch.prim.ListConstruct %int2, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
+  %3 = torch.aten.max_pool2d %arg0, %0, %1, %2, %2, %false : !torch.vtensor<[?,?,?,?],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool -> !torch.vtensor<[?,?,?,?],f32>
+  return %3 : !torch.vtensor<[?,?,?,?],f32>
+}
+
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.max_pool2d_with_indices(
+// CHECK-SAME:      %[[VAL_0:.*]]: !torch.vtensor<[?,?,?],f32>) -> (!torch.vtensor<[?,?,?],f32>, !torch.vtensor<[?,?,?],si64>) {
+// CHECK:         %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?,?],f32> -> tensor<?x?x?xf32>
+// CHECK:         %int3 = torch.constant.int 3
+// CHECK:         %int2 = torch.constant.int 2
+// CHECK:         %false = torch.constant.bool false
+// CHECK:         %int0 = torch.constant.int 0
+// CHECK:         %int1 = torch.constant.int 1
+// CHECK:         %[[VAL_2:.*]] = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK:         %[[VAL_3:.*]] = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK:         %[[VAL_4:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK:         %[[VAL_5:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK:         %[[VAL_6:.*]] = mhlo.constant dense<-3.40282347E+38> : tensor<f32>
+// CHECK:         %[[IDX_0:.*]] = arith.constant 0 : index
+// CHECK:         %[[VAL_7:.*]] = tensor.dim %[[VAL_1]], %[[IDX_0]] : tensor<?x?x?xf32>
+// CHECK:         %[[VAL_8:.*]] = arith.index_cast %[[VAL_7]] : index to i64
+// CHECK:         %[[IDX_1:.*]] = arith.constant 1 : index
+// CHECK:         %[[VAL_9:.*]] = tensor.dim %[[VAL_1]], %[[IDX_1]] : tensor<?x?x?xf32>
+// CHECK:         %[[VAL_10:.*]] = arith.index_cast %[[VAL_9]] : index to i64
+// CHECK:         %[[IDX_2:.*]] = arith.constant 2 : index
+// CHECK:         %[[VAL_11:.*]] = tensor.dim %[[VAL_1]], %[[IDX_2]] : tensor<?x?x?xf32>
+// CHECK:         %[[VAL_12:.*]] = arith.index_cast %[[VAL_11]] : index to i64
+// CHECK:         %[[VAL_13:.*]] = tensor.from_elements %[[VAL_8]], %[[VAL_10]], %[[VAL_12]] : tensor<3xi64>
+// CHECK:         %[[VAL_14:.*]] = arith.muli %[[VAL_12]], %[[VAL_10]] : i64
+// CHECK:         %[[VAL_15:.*]] = tensor.from_elements %[[VAL_8]], %[[VAL_14]] : tensor<2xi64>
+// CHECK:         %[[VAL_16:.*]] = "mhlo.dynamic_iota"(%[[VAL_15]]) {iota_dimension = 1 : i64} : (tensor<2xi64>) -> tensor<?x?xi64>
+// CHECK:         %[[VAL_17:.*]] = "mhlo.dynamic_reshape"(%[[VAL_16]], %[[VAL_13]]) : (tensor<?x?xi64>, tensor<3xi64>) -> tensor<?x?x?xi64>
+// CHECK:         %[[VAL_18:.*]] = mhlo.constant dense<0> : tensor<i64>
+// CHECK:         %[[VAL_19:.*]]:2 = "mhlo.reduce_window"(%[[VAL_1]], %[[VAL_17]], %[[VAL_6]], %[[VAL_18]]) ({
+// CHECK:         ^bb0(%[[IVAL_0:.*]]: tensor<f32>, %[[IVAL_1:.*]]: tensor<i64>, %[[IVAL_2:.*]]: tensor<f32>, %[[IVAL_3:.*]]: tensor<i64>):
+// CHECK:           %[[IVAL_4:.*]] = "mhlo.compare"(%[[IVAL_0]], %[[IVAL_2]]) {compare_type = #mhlo<comparison_type FLOAT>, comparison_direction = #mhlo<comparison_direction GE>} : (tensor<f32>, tensor<f32>) -> tensor<i1>
+// CHECK:           %[[IVAL_5:.*]] = "mhlo.select"(%[[IVAL_4]], %[[IVAL_0]], %[[IVAL_2]]) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
+// CHECK:           %[[IVAL_6:.*]] = "mhlo.compare"(%[[IVAL_0]], %[[IVAL_2]]) {compare_type = #mhlo<comparison_type FLOAT>, comparison_direction = #mhlo<comparison_direction EQ>} : (tensor<f32>, tensor<f32>) -> tensor<i1>
+// CHECK:           %[[IVAL_7:.*]] = mhlo.minimum %[[IVAL_1]], %[[IVAL_3]] : tensor<i64>
+// CHECK:           %[[IVAL_8:.*]] = "mhlo.select"(%[[IVAL_4]], %[[IVAL_1]], %[[IVAL_3]]) : (tensor<i1>, tensor<i64>, tensor<i64>) -> tensor<i64>
+// CHECK:           %[[IVAL_9:.*]] = "mhlo.select"(%[[IVAL_6]], %[[IVAL_7]], %[[IVAL_8]]) : (tensor<i1>, tensor<i64>, tensor<i64>) -> tensor<i64>
+// CHECK:           "mhlo.return"(%[[IVAL_5]], %[[IVAL_9]]) : (tensor<f32>, tensor<i64>) -> ()
+// CHECK{LITERAL}:  }) {padding = dense<0> : tensor<3x2xi64>, window_dilations = dense<1> : tensor<3xi64>, window_dimensions = dense<[1, 3, 3]> : tensor<3xi64>, window_strides = dense<[1, 2, 2]> : tensor<3xi64>} : (tensor<?x?x?xf32>, tensor<?x?x?xi64>, tensor<f32>, tensor<i64>) -> (tensor<?x?x?xf32>, tensor<?x?x?xi64>)
+// CHECK:         %[[VAL_20:.*]] = torch_c.from_builtin_tensor %[[VAL_19]]#0 : tensor<?x?x?xf32> -> !torch.vtensor<[?,?,?],f32>
+// CHECK:         %[[VAL_21:.*]] = torch_c.from_builtin_tensor %[[VAL_19]]#1 : tensor<?x?x?xi64> -> !torch.vtensor<[?,?,?],si64>
+// CHECK:         return %[[VAL_20]], %[[VAL_21]] : !torch.vtensor<[?,?,?],f32>, !torch.vtensor<[?,?,?],si64>
+func.func @torch.aten.max_pool2d_with_indices(%arg0: !torch.vtensor<[?,?,?],f32>) -> (!torch.vtensor<[?,?,?],f32>, !torch.vtensor<[?,?,?],si64>) {
+  %int3 = torch.constant.int 3
+  %int2 = torch.constant.int 2
+  %false = torch.constant.bool false
+  %int0 = torch.constant.int 0
+  %int1 = torch.constant.int 1
+  %0 = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
+  %1 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
+  %2 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
+  %3 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
+  %result0, %result1 = torch.aten.max_pool2d_with_indices %arg0, %0, %1, %2, %3, %false : !torch.vtensor<[?,?,?],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool -> !torch.vtensor<[?,?,?],f32>, !torch.vtensor<[?,?,?],si64>
+  return %result0, %result1 : !torch.vtensor<[?,?,?],f32>, !torch.vtensor<[?,?,?],si64>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.avg_pool2d(
+// CHECK-SAME:      %[[VAL_0:.*]]: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
+// CHECK:         %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?,?,?],f32> -> tensor<?x?x?x?xf32>
+// CHECK:         %int3 = torch.constant.int 3
+// CHECK:         %int2 = torch.constant.int 2
+// CHECK:         %int1 = torch.constant.int 1
+// CHECK:         %false = torch.constant.bool false
+// CHECK:         %none = torch.constant.none
+// CHECK:         %[[VAL_2:.*]] = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK:         %[[VAL_3:.*]] = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK:         %[[VAL_4:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK:         %[[VAL_5:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
+// CHECK:         %[[VAL_6:.*]] = "mhlo.reduce_window"(%[[VAL_1]], %[[VAL_5]]) ({
+// CHECK:         ^bb0(%[[IVAL_0:.*]]: tensor<f32>, %[[IVAL_1:.*]]: tensor<f32>):
+// CHECK:           %[[IVAL_2:.*]] = mhlo.add %[[IVAL_0]], %[[IVAL_1]] : tensor<f32>
+// CHECK:           "mhlo.return"(%[[IVAL_2]]) : (tensor<f32>) -> ()
+// CHECK{LITERAL}:  }) {padding = dense<[[0, 0], [0, 0], [1, 1], [1, 1]]> : tensor<4x2xi64>, window_dilations = dense<1> : tensor<4xi64>, window_dimensions = dense<[1, 1, 3, 3]> : tensor<4xi64>, window_strides = dense<[1, 1, 2, 2]> : tensor<4xi64>} : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32>
+// CHECK:         %[[VAL_7:.*]] = mhlo.constant dense<1.000000e+00> : tensor<f32>
+// CHECK:         %[[IDX_0:.*]] = arith.constant 0 : index
+// CHECK:         %[[VAL_8:.*]] = tensor.dim %[[VAL_1]], %[[IDX_0]] : tensor<?x?x?x?xf32>
+// CHECK:         %[[VAL_9:.*]] = arith.index_cast %[[VAL_8]] : index to i64
+// CHECK:         %[[IDX_1:.*]] = arith.constant 1 : index
+// CHECK:         %[[VAL_10:.*]] = tensor.dim %[[VAL_1]], %[[IDX_1]] : tensor<?x?x?x?xf32>
+// CHECK:         %[[VAL_11:.*]] = arith.index_cast %[[VAL_10]] : index to i64
+// CHECK:         %[[IDX_2:.*]] = arith.constant 2 : index
+// CHECK:         %[[VAL_12:.*]] = tensor.dim %[[VAL_1]], %[[IDX_2]] : tensor<?x?x?x?xf32>
+// CHECK:         %[[VAL_13:.*]] = arith.index_cast %[[VAL_12]] : index to i64
+// CHECK:         %[[IDX_3:.*]] = arith.constant 3 : index
+// CHECK:         %[[VAL_14:.*]] = tensor.dim %[[VAL_1]], %[[IDX_3]] : tensor<?x?x?x?xf32>
+// CHECK:         %[[VAL_15:.*]] = arith.index_cast %[[VAL_14]] : index to i64
+// CHECK:         %[[VAL_16:.*]] = tensor.from_elements %[[VAL_9]], %[[VAL_11]], %[[VAL_13]], %[[VAL_15]] : tensor<4xi64>
+// CHECK:         %[[VAL_17:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[VAL_7]], %[[VAL_16]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>, tensor<4xi64>) -> tensor<?x?x?x?xf32>
+// CHECK:         %[[VAL_18:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
+// CHECK:         %[[VAL_19:.*]] = "mhlo.reduce_window"(%[[VAL_17]], %[[VAL_18]]) ({
+// CHECK:         ^bb0(%[[IVAL_3:.*]]: tensor<f32>, %[[IVAL_4:.*]]: tensor<f32>):
+// CHECK:           %[[IVAL_5:.*]] = mhlo.add %[[IVAL_3]], %[[IVAL_4]] : tensor<f32>
+// CHECK:           "mhlo.return"(%[[IVAL_5]]) : (tensor<f32>) -> ()
+// CHECK{LITERAL}:  }) {padding = dense<[[0, 0], [0, 0], [1, 1], [1, 1]]> : tensor<4x2xi64>, window_dilations = dense<1> : tensor<4xi64>, window_dimensions = dense<[1, 1, 3, 3]> : tensor<4xi64>, window_strides = dense<[1, 1, 2, 2]> : tensor<4xi64>} : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32>
+// CHECK:         %[[VAL_20:.*]] = mhlo.divide %[[VAL_6]], %[[VAL_19]] : tensor<?x?x?x?xf32>
+// CHECK:         %[[VAL_21:.*]] = torch_c.from_builtin_tensor %[[VAL_20]] : tensor<?x?x?x?xf32> -> !torch.vtensor<[?,?,?,?],f32>
+// CHECK:         return %[[VAL_21]] : !torch.vtensor<[?,?,?,?],f32>
+func.func @torch.aten.avg_pool2d(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
+  %int3 = torch.constant.int 3
+  %int2 = torch.constant.int 2
+  %int1 = torch.constant.int 1
+  %false = torch.constant.bool false
+  %none = torch.constant.none
+  %0 = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
+  %1 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
+  %2 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
+  %3 = torch.aten.avg_pool2d %arg0, %0, %1, %2, %false, %false, %none : !torch.vtensor<[?,?,?,?],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?,?],f32>
+  return %3 : !torch.vtensor<[?,?,?,?],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.avg_pool2d$count_include_pad(
+// CHECK-SAME:      %[[VAL_0:.*]]: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
+// CHECK:         %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?,?,?],f32> -> tensor<?x?x?x?xf32>
+// CHECK:         %int3 = torch.constant.int 3
+// CHECK:         %int2 = torch.constant.int 2
+// CHECK:         %int1 = torch.constant.int 1
+// CHECK:         %false = torch.constant.bool false
+// CHECK:         %none = torch.constant.none
+// CHECK:         %[[VAL_2:.*]] = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK:         %[[VAL_3:.*]] = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK:         %[[VAL_4:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK:         %[[VAL_5:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
+// CHECK:         %[[VAL_6:.*]] = "mhlo.reduce_window"(%[[VAL_1]], %[[VAL_5]]) ({
+// CHECK:         ^bb0(%[[IVAL_0:.*]]: tensor<f32>, %[[IVAL_1:.*]]: tensor<f32>):
+// CHECK:           %[[IVAL_2:.*]] = mhlo.add %[[IVAL_0]], %[[IVAL_1]] : tensor<f32>
+// CHECK:           "mhlo.return"(%[[IVAL_2]]) : (tensor<f32>) -> ()
+// CHECK{LITERAL}:  }) {padding = dense<[[0, 0], [0, 0], [1, 1], [1, 1]]> : tensor<4x2xi64>, window_dilations = dense<1> : tensor<4xi64>, window_dimensions = dense<[1, 1, 3, 3]> : tensor<4xi64>, window_strides = dense<[1, 1, 2, 2]> : tensor<4xi64>} : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32>
+// CHECK:         %[[VAL_7:.*]] = mhlo.constant dense<9> : tensor<i64>
+// CHECK:         %[[VAL_8:.*]] = mhlo.convert(%[[VAL_7]]) : (tensor<i64>) -> tensor<f32>
+// CHECK:         %[[VAL_9:.*]] = chlo.broadcast_divide %[[VAL_6]], %[[VAL_8]] : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32>
+// CHECK:         %[[VAL_10:.*]] = torch_c.from_builtin_tensor %[[VAL_9]] : tensor<?x?x?x?xf32> -> !torch.vtensor<[?,?,?,?],f32>
+// CHECK:         return %[[VAL_10]] : !torch.vtensor<[?,?,?,?],f32>
+func.func @torch.aten.avg_pool2d$count_include_pad(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
+  %int3 = torch.constant.int 3
+  %int2 = torch.constant.int 2
+  %int1 = torch.constant.int 1
+  %false = torch.constant.bool false
+  %true = torch.constant.bool true
+  %none = torch.constant.none
+  %0 = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
+  %1 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
+  %2 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
+  %3 = torch.aten.avg_pool2d %arg0, %0, %1, %2, %false, %true, %none : !torch.vtensor<[?,?,?,?],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?,?],f32>
+  return %3 : !torch.vtensor<[?,?,?,?],f32>
+}
\ No newline at end of file