[stablehlo] support aten_adaptive_max_pool1d
yyp0 committed Sep 24, 2024
1 parent e4f2bdf commit a07e372
Showing 5 changed files with 282 additions and 1 deletion.
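
In short: the commit adds an aten.max_pool1d_with_indices op to the Torch dialect, decomposes aten.adaptive_max_pool1d into it, and lowers the new op to stablehlo.reduce_window. A minimal PyTorch sketch of the kind of module this targets (mirroring the new e2e test; the class name is illustrative only):

import torch

class Pool1d(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # output_size=1 is the case the decomposition handles without a
        # runtime input-size check: the kernel spans the whole pooled dim.
        self.pool = torch.nn.AdaptiveMaxPool1d(output_size=1)

    def forward(self, x):
        return self.pool(x)

print(Pool1d()(torch.rand(1, 512, 7)).shape)  # torch.Size([1, 512, 1])
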
29 changes: 29 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -7078,6 +7078,35 @@ def Torch_AtenMaxPool1dOp : Torch_Op<"aten.max_pool1d", [
}];
}

def Torch_AtenMaxPool1dWithIndicesOp : Torch_Op<"aten.max_pool1d_with_indices", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::max_pool1d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchListOfTorchIntType:$kernel_size,
AnyTorchListOfTorchIntType:$stride,
AnyTorchListOfTorchIntType:$padding,
AnyTorchListOfTorchIntType:$dilation,
Torch_BoolType:$ceil_mode
);
let results = (outs
AnyTorchOptionalTensorType:$result0,
AnyTorchOptionalTensorType:$result1
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenMaxPool1dWithIndicesOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 6, 2);
}
void AtenMaxPool1dWithIndicesOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 6, 2);
}
}];
}
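
The ODS definition above mirrors the registered ATen schema, so the op can be exercised directly from Python. A minimal sketch, assuming a standard PyTorch build; positional arguments follow the schema (self, kernel_size, stride, padding, dilation, ceil_mode):

import torch

x = torch.arange(8, dtype=torch.float32).reshape(1, 1, 8)
values, indices = torch.ops.aten.max_pool1d_with_indices(
    x, [2], [2], [0], [1], False)
print(values)   # tensor([[[1., 3., 5., 7.]]])
print(indices)  # tensor([[[1, 3, 5, 7]]])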

def Torch_AtenMaxPool2dOp : Torch_Op<"aten.max_pool2d", [
AllowsTypeRefinement,
HasValueSemantics,
158 changes: 157 additions & 1 deletion lib/Conversion/TorchToStablehlo/Pooling.cpp
@@ -52,7 +52,7 @@ static Value createInitialValueForAtenPoolingOp(Operation *op, Type elementTy,

// Max pooling
if (isa<AtenMaxPool1dOp, AtenMaxPool2dOp, AtenMaxPool3dOp,
AtenMaxPool2dWithIndicesOp>(op)) {
AtenMaxPool1dWithIndicesOp, AtenMaxPool2dWithIndicesOp>(op)) {
if (isa<mlir::FloatType>(elementTy)) {
auto constAttr = DenseElementsAttr::get(
constType,
@@ -73,6 +73,161 @@ static Value createInitialValueForAtenPoolingOp(Operation *op, Type elementTy,
return nullptr;
}

// AtenMaxPool1dWithIndicesOp
template <>
LogicalResult ConvertAtenOp<AtenMaxPool1dWithIndicesOp>::matchAndRewrite(
AtenMaxPool1dWithIndicesOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Value input = adaptor.getSelf();
auto inputTy = cast<RankedTensorType>(input.getType());
auto inputElemTy = inputTy.getElementType();
auto inputShape = inputTy.getShape();
auto inputRank = inputTy.getRank();

auto outValTy =
cast<RankedTensorType>(getTypeConverter()->convertType(op.getType(0)));
auto outIdxTy =
cast<RankedTensorType>(getTypeConverter()->convertType(op.getType(1)));

if (inputRank <= 1) {
return op.emitError(
"max_pooling1d only supports inputs with rank higher than 1");
}

SmallVector<int64_t, 1> padding, kernelSize, stride, dilation;
bool ceilMode = false;

if (!(matchPattern(op.getKernelSize(),
m_TorchListOfConstantInts(kernelSize)))) {
return rewriter.notifyMatchFailure(
op, "non-const int kernel size unsupported!");
}
if (!(matchPattern(op.getStride(), m_TorchListOfConstantInts(stride)))) {
return rewriter.notifyMatchFailure(op, "non-const int stride unsupported!");
}
if (!(matchPattern(op.getPadding(), m_TorchListOfConstantInts(padding)))) {
return rewriter.notifyMatchFailure(op,
"non-const int padding unsupported!");
}
if (!(matchPattern(op.getDilation(), m_TorchListOfConstantInts(dilation)))) {
return rewriter.notifyMatchFailure(op,
"non-const int dilation unsupported!");
}
if (!(matchPattern(op.getCeilMode(), m_TorchConstantBool(&ceilMode)))) {
return rewriter.notifyMatchFailure(op,
"non-const bool ceil_mode unsupported!");
}

SmallVector<int64_t> stablehloStride(inputRank, 1);
SmallVector<int64_t> stablehloDilation(inputRank, 1);
SmallVector<int64_t> stablehloKernelSize(inputRank, 1);
SmallVector<int64_t> stablehloPadding(inputRank * 2, 0);

std::copy(stride.begin(), stride.end(),
stablehloStride.begin() + inputRank - 1);
std::copy(dilation.begin(), dilation.end(),
stablehloDilation.begin() + inputRank - 1);
std::copy(kernelSize.begin(), kernelSize.end(),
stablehloKernelSize.begin() + inputRank - 1);
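// stablehlo takes (low, high) padding pairs per dimension; only the pooled
// (last) dimension receives the torch padding value, on both sides.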
stablehloPadding[stablehloPadding.size() - 1] = padding[0];
stablehloPadding[stablehloPadding.size() - 2] = padding[0];

Value initVal = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter);

auto windowDimensions = rewriter.getDenseI64ArrayAttr(stablehloKernelSize);
auto windowStrides = rewriter.getDenseI64ArrayAttr(stablehloStride);
auto windowDilations = rewriter.getDenseI64ArrayAttr(stablehloDilation);
DenseIntElementsAttr pad = DenseIntElementsAttr::get(
RankedTensorType::get(
{static_cast<int64_t>(inputRank), static_cast<int64_t>(2)},
rewriter.getI64Type()),
stablehloPadding);
DenseI64ArrayAttr baseDilations;

auto inputShapeInfo = hlo::getDimIndexOfTensor(rewriter, op, input);
if (failed(inputShapeInfo)) {
return rewriter.notifyMatchFailure(
op, "failed to get dimension sizes of the input");
}
auto inputShapeVec = *inputShapeInfo;
auto inputShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
op->getLoc(), inputShapeVec);

// No reshape is needed for max_pool1d; the index iota runs along the pooled
// (last) dimension, i.e. iota dimension = inputRank - 1.
auto indexTensor =
rewriter
.create<stablehlo::DynamicIotaOp>(
op->getLoc(),
RankedTensorType::get(inputShape, rewriter.getI64Type()),
inputShapeTensor, static_cast<uint64_t>(inputRank - 1))
.getResult();
Value initIdx = hlo::getConstTensor<int64_t>(rewriter, op, {0}, {}).value();

auto reduceWindowOp = rewriter.create<stablehlo::ReduceWindowOp>(
op->getLoc(), mlir::TypeRange{outValTy, outIdxTy},
mlir::ValueRange{input, indexTensor}, mlir::ValueRange{initVal, initIdx},
windowDimensions, windowStrides, baseDilations, windowDilations, pad);

// Build the reduction block: keep the larger value and, on ties, the smaller index.
Block &block = reduceWindowOp.getBody().emplaceBlock();
auto blockValArgumentType = RankedTensorType::get({}, inputElemTy);
auto blockIdxArgumentType = RankedTensorType::get({}, rewriter.getI64Type());
auto compareResultType = RankedTensorType::get({}, rewriter.getI1Type());
block.addArgument(blockValArgumentType, op->getLoc());
block.addArgument(blockIdxArgumentType, op->getLoc());
block.addArgument(blockValArgumentType, op->getLoc());
block.addArgument(blockIdxArgumentType, op->getLoc());
auto *firstValArg = block.args_begin();
auto *firstIdxArg = std::next(firstValArg);
auto *secondValArg = std::next(firstIdxArg);
auto *secondIdxArg = std::next(secondValArg);

stablehlo::ComparisonTypeAttr compareTypeAttr;
if (isa<mlir::FloatType>(inputTy.getElementType())) {
compareTypeAttr = stablehlo::ComparisonTypeAttr::get(
rewriter.getContext(), stablehlo::ComparisonType::FLOAT);
} else if (isa<mlir::IntegerType>(inputTy.getElementType())) {
compareTypeAttr = stablehlo::ComparisonTypeAttr::get(
rewriter.getContext(), stablehlo::ComparisonType::SIGNED);
}

stablehlo::ComparisonDirectionAttr compareGeDirectionAttr =
stablehlo::ComparisonDirectionAttr::get(
rewriter.getContext(), stablehlo::ComparisonDirection::GE);
stablehlo::ComparisonDirectionAttr compareEqDirectionAttr =
stablehlo::ComparisonDirectionAttr::get(
rewriter.getContext(), stablehlo::ComparisonDirection::EQ);

{
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(&block);

Value compareGeResult = rewriter.create<stablehlo::CompareOp>(
op->getLoc(), compareResultType, *firstValArg, *secondValArg,
compareGeDirectionAttr, compareTypeAttr);
Value retValResult = rewriter.create<stablehlo::SelectOp>(
op->getLoc(), compareGeResult, *firstValArg, *secondValArg);

// Get smaller index if compared values are equal.
Value compareEqResult = rewriter.create<stablehlo::CompareOp>(
op->getLoc(), compareResultType, *firstValArg, *secondValArg,
compareEqDirectionAttr, compareTypeAttr);
Value minIdx = rewriter.create<stablehlo::MinOp>(op->getLoc(), *firstIdxArg,
*secondIdxArg);
Value idxWithGeVal = rewriter.create<stablehlo::SelectOp>(
op->getLoc(), compareGeResult, *firstIdxArg, *secondIdxArg);
Value retIdxResult = rewriter.create<stablehlo::SelectOp>(
op->getLoc(), compareEqResult, minIdx, idxWithGeVal);

rewriter.create<stablehlo::ReturnOp>(
op->getLoc(), mlir::ValueRange{retValResult, retIdxResult});
}

rewriter.replaceOp(op, reduceWindowOp.getResults());
return success();
}
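
For intuition, the reduce_window body built above reduces (value, index) pairs: it keeps the larger value and, when values tie, the smaller index. A pure-Python sketch of that per-window reduction (illustrative only, not torch-mlir code):

def window_max_with_index(vals, idxs):
    # Keep the larger value; on a tie, keep the smaller index, mirroring the
    # GE / EQ / min logic in the reduction block above.
    best_v, best_i = vals[0], idxs[0]
    for v, i in zip(vals[1:], idxs[1:]):
        if v > best_v or (v == best_v and i < best_i):
            best_v, best_i = v, i
    return best_v, best_i

print(window_max_with_index([2.0, 5.0, 5.0], [4, 5, 6]))  # (5.0, 5)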

// AtenMaxPool2dWithIndicesOp
template <>
LogicalResult ConvertAtenOp<AtenMaxPool2dWithIndicesOp>::matchAndRewrite(
@@ -657,6 +812,7 @@ void mlir::torch::torch_to_stablehlo::populatePoolingOpPatternsAndLegality(
#define INSERT_ATEN_POOLING_PATTERN(AtenOp) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context, options)
INSERT_ATEN_POOLING_PATTERN(AtenMaxPool1dWithIndicesOp);
INSERT_ATEN_POOLING_PATTERN(AtenMaxPool2dWithIndicesOp);
INSERT_ATEN_POOLING_PATTERN(AtenCumsumOp);
#undef INSERT_ATEN_POOLING_PATTERN
81 changes: 81 additions & 0 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -7298,6 +7298,86 @@ class DecomposeAtenToDeviceOp : public OpRewritePattern<AtenToDeviceOp> {
};
} // namespace

namespace {
// Decompose `aten.adaptive_max_pool1d` op into `aten.max_pool1d_with_indices`
// op.
class DecomposeAtenAdaptiveMaxPool1dOp
: public OpRewritePattern<AtenAdaptiveMaxPool1dOp> {
using OpRewritePattern<AtenAdaptiveMaxPool1dOp>::OpRewritePattern;
LogicalResult matchAndRewrite(AtenAdaptiveMaxPool1dOp op,
PatternRewriter &rewriter) const override {
Location loc = op->getLoc();
MLIRContext *context = op.getContext();

Value input = op.getSelf();
std::optional<unsigned> maybeRank = getTensorRank(input);
if (!maybeRank) {
return rewriter.notifyMatchFailure(op, "expected input to have a rank");
}
unsigned rank = *maybeRank;
Value sizeDim = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(rank - 1));
Value inputSize = rewriter.create<AtenSizeIntOp>(loc, input, sizeDim);

Value outputShape = op.getOutputSize();
SmallVector<Value> outputShapeSizesTorchInt;
getListConstructElements(outputShape, outputShapeSizesTorchInt);
Value outputSize = outputShapeSizesTorchInt[0];

Value constantOne = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(1));
Value constantZero = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(0));
Value constantFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);
Value constantTrue = rewriter.create<Torch::ConstantBoolOp>(loc, true);

int64_t outputSizeInt;
if (!matchPattern(outputSize, m_TorchConstantInt(&outputSizeInt))) {
return rewriter.notifyMatchFailure(
op, "the output size of adaptive_max_pool1d must be a constant int");
}

SmallVector<Value, 1> kernelSize;
if (outputSizeInt == 1) {
BaseTensorType inputTensorType = cast<BaseTensorType>(input.getType());
ArrayRef<int64_t> inputShape = inputTensorType.getSizes();
kernelSize.push_back(
inputShape[rank - 1] == kUnknownSize
? inputSize
: rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(inputShape[rank - 1])));
} else {
if (!isAssumingStrictSymbolicShapes(rewriter)) {
Value cond = rewriter.create<AtenEqIntOp>(loc, inputSize, outputSize);
rewriter.create<RuntimeAssertOp>(
loc, cond,
"unimplemented: only support cases where input and output size are "
"equal for non-unit output size");
}
kernelSize.push_back(constantOne);
}

Value kernelSizeList = rewriter.create<PrimListConstructOp>(
loc, Torch::ListType::get(Torch::IntType::get(context)), kernelSize);
Value strideList = rewriter.create<PrimListConstructOp>(
loc, Torch::ListType::get(Torch::IntType::get(context)),
ValueRange{constantOne});
Value paddingSizeList = rewriter.create<PrimListConstructOp>(
loc, Torch::ListType::get(Torch::IntType::get(context)),
ValueRange{constantZero});
Value dilationList = rewriter.create<PrimListConstructOp>(
loc, Torch::ListType::get(Torch::IntType::get(context)),
ValueRange{constantOne});

rewriter.replaceOpWithNewOp<AtenMaxPool1dWithIndicesOp>(
op, op.getType(0), op.getType(1), input, kernelSizeList, strideList,
paddingSizeList, dilationList,
/*ceil_mode=*/constantFalse);
return success();
}
};
} // namespace
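
The decomposition above relies on a simple identity: adaptive max pooling to output size 1 is an ordinary max pool whose kernel covers the whole pooled dimension, while any other output size is only supported when it equals the input size (kernel 1). A quick numeric check of the output_size == 1 case, as a hedged sketch:

import torch
import torch.nn.functional as F

x = torch.rand(2, 4, 9)
adaptive = F.adaptive_max_pool1d(x, 1)
# What the decomposition emits for output_size == 1: a kernel over the full length.
pooled, _ = F.max_pool1d(x, kernel_size=x.shape[-1], return_indices=True)
print(torch.allclose(adaptive, pooled))  # True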

namespace {
// Decompose `aten.adaptive_avg_pool1d` op into `aten.avg_pool1d` op.

@@ -9801,6 +9881,7 @@ class DecomposeComplexOpsPass
addPatternIfTargetOpIsIllegal<DecomposeAtenToDtypeLayoutOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenToDeviceOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenToPrimDeviceOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenAdaptiveMaxPool1dOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenAdaptiveAvgPool1dOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenAdaptiveAvgPool2dOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenClampMinOp>(patterns);
@@ -617,6 +617,7 @@ def emit_with_mutating_variants(key, **kwargs):
"aten::native_layer_norm : (Tensor, int[], Tensor?, Tensor?, float) -> (Tensor, Tensor, Tensor)"
)
emit("aten::max_pool1d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)")
emit("aten::max_pool1d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)")
emit("aten::max_pool2d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)")
emit("aten::max_unpool2d : (Tensor, Tensor, int[]) -> (Tensor)")
emit(
14 changes: 14 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py
@@ -1782,6 +1782,20 @@ def forward(self, x):
def AdaptiveMaxPool1dStatic_basic(module, tu: TestUtils):
module.forward(tu.rand(1, 512, 10))

class AdaptiveMaxPool1dDimOneStatic(torch.nn.Module):
def __init__(self):
super().__init__()
self.amp1d = torch.nn.AdaptiveMaxPool1d(output_size=(1), return_indices=False)

@export
@annotate_args([None, ([1, 512, 7], torch.float32, True)])
def forward(self, x):
return self.amp1d(x)


@register_test_case(module_factory=lambda: AdaptiveMaxPool1dDimOneStatic())
def AdaptiveMaxPool1dDimOneStatic_basic(module, tu: TestUtils):
module.forward(tu.rand(1, 512, 7))

# AdaptiveMaxPool2d

