@@ -760,17 +760,23 @@ class FullyConnectedConverter
760760 }
761761};
762762
763- class MaxPool2dConverter : public OpRewritePattern <tosa::MaxPool2dOp> {
763+ class MaxPool2dConverter : public OpConversionPattern <tosa::MaxPool2dOp> {
764764public:
765- using OpRewritePattern<tosa::MaxPool2dOp>::OpRewritePattern ;
765+ using OpConversionPattern::OpConversionPattern ;
766766
767- LogicalResult matchAndRewrite (tosa::MaxPool2dOp op,
768- PatternRewriter &rewriter) const final {
767+ LogicalResult
768+ matchAndRewrite (tosa::MaxPool2dOp op, OpAdaptor adaptor,
769+ ConversionPatternRewriter &rewriter) const final {
769770 Location loc = op.getLoc ();
770- Value input = op .getInput ();
771+ Value input = adaptor .getInput ();
771772 ShapedType inputTy = cast<ShapedType>(input.getType ());
772773
773- ShapedType resultTy = cast<ShapedType>(op.getType ());
774+ bool isUnsigned =
775+ cast<ShapedType>(op.getType ()).getElementType ().isUnsignedInteger ();
776+ ShapedType resultTy =
777+ cast<ShapedType>(getTypeConverter ()->convertType (op.getType ()));
778+ if (!resultTy)
779+ return rewriter.notifyMatchFailure (op, " failed to convert type" );
774780 Type resultETy = inputTy.getElementType ();
775781
776782 auto dynamicDimsOr =
@@ -786,7 +792,10 @@ class MaxPool2dConverter : public OpRewritePattern<tosa::MaxPool2dOp> {
786792 resultETy, APFloat::getLargest (
787793 cast<FloatType>(resultETy).getFloatSemantics (), true ));
788794
789- if (isa<IntegerType>(resultETy))
795+ else if (isUnsigned)
796+ initialAttr = rewriter.getIntegerAttr (
797+ resultETy, APInt::getZero (resultETy.getIntOrFloatBitWidth ()));
798+ else if (isa<IntegerType>(resultETy))
790799 initialAttr = rewriter.getIntegerAttr (
791800 resultETy,
792801 APInt::getSignedMinValue (resultETy.getIntOrFloatBitWidth ()));
@@ -823,9 +832,15 @@ class MaxPool2dConverter : public OpRewritePattern<tosa::MaxPool2dOp> {
823832 Value fakeWindowDims =
824833 rewriter.create <tensor::EmptyOp>(loc, kernel, resultETy);
825834
826- rewriter.replaceOpWithNewOp <linalg::PoolingNhwcMaxOp>(
827- op, ArrayRef<Type>{resultTy}, ValueRange{paddedInput, fakeWindowDims},
828- filledEmptyTensor, strideAttr, dilationAttr);
835+ if (isUnsigned) {
836+ rewriter.replaceOpWithNewOp <linalg::PoolingNhwcMaxUnsignedOp>(
837+ op, ArrayRef<Type>{resultTy}, ValueRange{paddedInput, fakeWindowDims},
838+ filledEmptyTensor, strideAttr, dilationAttr);
839+ } else {
840+ rewriter.replaceOpWithNewOp <linalg::PoolingNhwcMaxOp>(
841+ op, ArrayRef<Type>{resultTy}, ValueRange{paddedInput, fakeWindowDims},
842+ filledEmptyTensor, strideAttr, dilationAttr);
843+ }
829844 return success ();
830845 }
831846};
@@ -1091,7 +1106,8 @@ class TransposeConverter : public OpRewritePattern<tosa::TransposeOp> {
10911106} // namespace
10921107
10931108void mlir::tosa::populateTosaToLinalgNamedConversionPatterns (
1094- RewritePatternSet *patterns, const TosaToLinalgNamedOptions &options) {
1109+ TypeConverter &converter, RewritePatternSet *patterns,
1110+ const TosaToLinalgNamedOptions &options) {
10951111 if (options.preferConv2DKernelLayoutHWCF ) {
10961112 patterns->add <ConvConverter<tosa::Conv2DOp, linalg::Conv2DNhwcHwcfOp,
10971113 linalg::Conv2DNhwcHwcfQOp>>(
@@ -1105,11 +1121,13 @@ void mlir::tosa::populateTosaToLinalgNamedConversionPatterns(
11051121 // clang-format off
11061122 ConvConverter<tosa::Conv3DOp, linalg::Conv3DNdhwcDhwcfOp, linalg::Conv3DNdhwcDhwcfQOp>,
11071123 DepthwiseConvConverter,
1108- MaxPool2dConverter,
11091124 AvgPool2dConverter,
11101125 FullyConnectedConverter,
11111126 TransposeConverter
11121127 >(patterns->getContext ());
1128+ patterns->add <
1129+ MaxPool2dConverter
1130+ >(converter, patterns->getContext ());
11131131 patterns->add <
11141132 MatMulConverter>(patterns->getContext (), options.useMatmulForSingleBatch );
11151133 // clang-format on
0 commit comments