Skip to content

[TorchToLinalg] Support AtenReplicationPad1d with lowering to linalg backend #4217

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -10068,6 +10068,30 @@ def Torch_AtenConstantPadNdOp : Torch_Op<"aten.constant_pad_nd", [
}];
}

// NOTE: this file (GeneratedTorchOps.td) is auto-generated from the
// emit("aten::replication_pad1d : (Tensor, int[]) -> (Tensor)") entry in
// torch_ods_gen.py — regenerate rather than hand-editing.
def Torch_AtenReplicationPad1dOp : Torch_Op<"aten.replication_pad1d", [
    AllowsTypeRefinement,
    HasValueSemantics,
    ReadOnly
  ]> {
  let summary = "Generated op for `aten::replication_pad1d : (Tensor, int[]) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$self,
    // `padding` is the (left, right) pad pair for the last dimension.
    AnyTorchListOfTorchIntType:$padding
  );
  let results = (outs
    AnyTorchOptionalTensorType:$result
  );
  let hasCustomAssemblyFormat = 1;
  let extraClassDefinition = [{
    ParseResult AtenReplicationPad1dOp::parse(OpAsmParser &parser, OperationState &result) {
      // 2 operands, 1 result — the generic torch-op assembly format.
      return parseDefaultTorchOp(parser, result, 2, 1);
    }
    void AtenReplicationPad1dOp::print(OpAsmPrinter &printer) {
      printDefaultTorchOp(printer, *this, 2, 1);
    }
  }];
}

def Torch_AtenReplicationPad2dOp : Torch_Op<"aten.replication_pad2d", [
AllowsTypeRefinement,
HasValueSemantics,
Expand Down
90 changes: 90 additions & 0 deletions lib/Conversion/TorchToLinalg/TensorConstructors.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,94 @@ class ConvertAtenConstantPadNdOp

namespace {

// Lower aten.replication_pad1d into tensor.extract_slice + tensor.concat:
// the edge element along the last dimension is sliced out once per padded
// side and concatenated `pad` times before/after the input.
class ConvertAtenReplicationPad1dOp
    : public OpConversionPattern<AtenReplicationPad1dOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(AtenReplicationPad1dOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
      return failure();

    Location loc = op.getLoc();
    Value input = adaptor.getSelf();
    auto inputType = llvm::cast<RankedTensorType>(input.getType());
    int64_t inputRank = inputType.getRank();

    // replication_pad1d pads the last dim of a (C, W) or (N, C, W) tensor.
    if (inputRank < 2)
      return rewriter.notifyMatchFailure(op, "input rank must be at least 2");

    SmallVector<int64_t> padInts;
    if (!matchPattern(op.getPadding(), m_TorchListOfConstantInts(padInts)))
      return rewriter.notifyMatchFailure(
          op, "only support constant int pad ranges");

    if (padInts.size() != 2)
      return rewriter.notifyMatchFailure(
          op, "pad range must have exactly two values");

    int64_t leftPad = padInts[0];
    int64_t rightPad = padInts[1];

    // Negative padding means trimming, not replication; bail out instead of
    // silently emitting a mis-shaped result.
    if (leftPad < 0 || rightPad < 0)
      return rewriter.notifyMatchFailure(
          op, "negative padding values are not supported");

    Type resultType = getTypeConverter()->convertType(op.getType());

    // Nothing to pad: just reconcile the result type.
    if (leftPad == 0 && rightPad == 0) {
      rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, input);
      return success();
    }

    int64_t dimToPad = inputRank - 1;
    SmallVector<Value> inputShape = getTensorSizes(rewriter, loc, input);

    // Slice geometry shared by both edges: size 1 along the padded dim,
    // full extent elsewhere, unit strides.
    SmallVector<OpFoldResult> strides(inputRank, rewriter.getIndexAttr(1));
    SmallVector<OpFoldResult> sizes(inputRank, rewriter.getIndexAttr(1));
    for (int64_t i = 0; i < inputRank; ++i)
      if (i != dimToPad)
        sizes[i] = getAsOpFoldResult(inputShape[i]);

    SmallVector<Value> resultParts;

    if (leftPad > 0) {
      // Leftmost element lives at offset 0 along the padded dim.
      SmallVector<OpFoldResult> offsets(inputRank, rewriter.getIndexAttr(0));
      Value leftSlice = rewriter.create<tensor::ExtractSliceOp>(
          loc, input, offsets, sizes, strides);
      SmallVector<Value> tiles(leftPad, leftSlice);
      resultParts.push_back(
          rewriter.create<tensor::ConcatOp>(loc, dimToPad, tiles));
    }

    resultParts.push_back(input);

    if (rightPad > 0) {
      // Rightmost element lives at offset width - 1 along the padded dim.
      Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
      Value lastIndex =
          rewriter.create<arith::SubIOp>(loc, inputShape[dimToPad], one);
      SmallVector<OpFoldResult> offsets(inputRank, rewriter.getIndexAttr(0));
      offsets[dimToPad] = getAsOpFoldResult(lastIndex);
      Value rightSlice = rewriter.create<tensor::ExtractSliceOp>(
          loc, input, offsets, sizes, strides);
      SmallVector<Value> tiles(rightPad, rightSlice);
      resultParts.push_back(
          rewriter.create<tensor::ConcatOp>(loc, dimToPad, tiles));
    }

    Value result =
        rewriter.create<tensor::ConcatOp>(loc, dimToPad, resultParts);
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, result);

    return success();
  }
};

} // namespace

namespace {

// Lower aten.replication_pad2d operator into a sequence of
// tensor.extract_slice and tensor.concat operations.

Expand Down Expand Up @@ -621,6 +709,8 @@ void mlir::torch::torch_to_linalg::
MLIRContext *context = patterns.getContext();
target.addIllegalOp<AtenReplicationPad2dOp>();
patterns.add<ConvertAtenReplicationPad2dOp>(typeConverter, context);
target.addIllegalOp<AtenReplicationPad1dOp>();
patterns.add<ConvertAtenReplicationPad1dOp>(typeConverter, context);
target.addIllegalOp<AtenConstantPadNdOp>();
patterns.add<ConvertAtenConstantPadNdOp>(typeConverter, context);
target.addIllegalOp<AtenZerosOp, AtenOnesOp>();
Expand Down
29 changes: 29 additions & 0 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10805,6 +10805,31 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" } : (!torch.int, !torch.bool) -> ()\n"
" return %arg0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.replication_pad1d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
" %false = torch.constant.bool false\n"
" %str = torch.constant.str \"AssertionError: padding size expected to be 2\"\n"
" %none = torch.constant.none\n"
" %str_0 = torch.constant.str \"AssertionError: \"\n"
" %int2 = torch.constant.int 2\n"
" %0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
" %1 = torch.aten.ge.int %0, %int2 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %1 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %2 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int\n"
" %3 = torch.aten.eq.int %2, %int2 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %3 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %4 = call @__torch__.pad_shape_fn(%arg0, %arg1, %false) : (!torch.list<int>, !torch.list<int>, !torch.bool) -> !torch.list<int>\n"
" return %4 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.replication_pad2d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
" %false = torch.constant.bool false\n"
" %str = torch.constant.str \"AssertionError: padding size expected to be 4\"\n"
Expand All @@ -10831,6 +10856,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %4 = call @__torch__.pad_shape_fn(%arg0, %arg1, %false) : (!torch.list<int>, !torch.list<int>, !torch.bool) -> !torch.list<int>\n"
" return %4 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.replication_pad1d\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.replication_pad2d\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
Expand Down
25 changes: 14 additions & 11 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8314,12 +8314,6 @@ class DecomposeAtenPadOp : public OpRewritePattern<AtenPadOp> {
}
}

// we don't have support for 1-D replicate pad, so pass it as 2d if
// possible.
// TODO: add support for AtenReplicatePad1dOp and remove this.
if (mode == "replicate" && usefulPadIndexEnd == 2 && padValues.size() >= 4)
usefulPadIndexEnd = 4;

// make a new list of padding ints if dimensionality reduction can be
// performed
if (usefulPadIndexEnd < padValues.size()) {
Expand Down Expand Up @@ -8357,11 +8351,20 @@ class DecomposeAtenPadOp : public OpRewritePattern<AtenPadOp> {
}

if (mode == "replicate") {
// only support for replication pad 2d
if (numPadDims != 2)
return failure();
rewriter.replaceOpWithNewOp<AtenReplicationPad2dOp>(
op, op.getType(), op.getSelf(), usefulPads);
switch (numPadDims) {
case 1:
rewriter.replaceOpWithNewOp<AtenReplicationPad1dOp>(
op, op.getType(), op.getSelf(), usefulPads);
break;
case 2:
rewriter.replaceOpWithNewOp<AtenReplicationPad2dOp>(
op, op.getType(), op.getSelf(), usefulPads);
break;
default:
return rewriter.notifyMatchFailure(
op, "unsupported number of dims for 'reflect' mode: " +
std::to_string(numPadDims));
}
return success();
}

Expand Down
6 changes: 6 additions & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -840,6 +840,8 @@
"ReflectionPad3dModuleRight_basic",
"ReflectionPad3dModuleFront_basic",
"ReflectionPad3dModuleBack_basic",
"ReplicationPad1dModule_2DInput_basic",
"ReplicationPad1dModule_3DInput_basic",
"ReplicationPad2dModule_basic",
"ReplicationPad2dModule_bottom0",
"ReplicationPad2dModule_left0",
Expand Down Expand Up @@ -3896,6 +3898,8 @@
"ScaledDotProductAttentionSameDynamicModule_basic",
"ScaledDotProductAttentionSameModule_basic",
"ScaledDotProductAttentionGQAModule_basic",
"ReplicationPad1dModule_2DInput_basic",
"ReplicationPad1dModule_3DInput_basic",
}

ONNX_TOSA_CRASHING_SET = {
Expand Down Expand Up @@ -4725,6 +4729,8 @@
"ReshapeCollapseModule_basic",
"ReshapeDynamicModule_basic",
"ReshapeExpandModule_basic",
"ReplicationPad1dModule_2DInput_basic",
"ReplicationPad1dModule_3DInput_basic",
"RollModule_basic",
"RsubIntModule_noalpha_basic",
"ScalarConstantTupleModule_basic",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2232,11 +2232,20 @@ def pad_shape_fn(input: List[int], pad: List[int], validate_pad : bool = False):
def aten〇constant_pad_nd〡shape(self: List[int], pad: List[int], value: float = 0) -> List[int]:
    # Result shape is the input shape widened by the (low, high) pairs in
    # `pad`; pad_shape_fn applies pairs to the trailing dims, last dim first.
    return pad_shape_fn(self, pad)

def aten〇replication_pad1d〡shape(self: List[int], padding: List[int]) -> List[int]:
    # replication_pad1d pads the last dim of an input with rank >= 2
    # (PyTorch accepts (C, W) or (N, C, W)); padding is one (left, right) pair.
    assert len(self) >= 2
    assert len(padding) == 2, 'padding size expected to be 2'
    return pad_shape_fn(self, padding)

def aten〇replication_pad2d〡shape(self: List[int], padding: List[int]) -> List[int]:
    # replication_pad2d pads the last two dims; padding is
    # (left, right, top, bottom).
    assert len(self) >= 2
    assert len(padding) == 4, 'padding size expected to be 4'
    return pad_shape_fn(self, padding)

def aten〇replication_pad1d〡dtype(self_rank_dtype: Tuple[int, int], padding: List[int]) -> int:
    # Padding never changes the element type; propagate the input dtype.
    self_rank, self_dtype = self_rank_dtype
    return self_dtype

def aten〇replication_pad2d〡dtype(self_rank_dtype: Tuple[int, int], padding: List[int]) -> int:
    # Padding never changes the element type; propagate the input dtype.
    self_rank, self_dtype = self_rank_dtype
    return self_dtype
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -798,6 +798,7 @@ def emit_with_mutating_variants(key, **kwargs):

# Misc tensor ops.
emit("aten::constant_pad_nd : (Tensor, int[], Scalar) -> (Tensor)")
emit("aten::replication_pad1d : (Tensor, int[]) -> (Tensor)")
emit("aten::replication_pad2d : (Tensor, int[]) -> (Tensor)")
emit("aten::reflection_pad1d : (Tensor, int[]) -> (Tensor)")
emit("aten::reflection_pad2d : (Tensor, int[]) -> (Tensor)")
Expand Down
46 changes: 46 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/padding.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,52 @@
# ==============================================================================


class ReplicationPad1dModule_3DInput(torch.nn.Module):
    # E2e test module: replication_pad1d on a 3-D (N, C, W) input with
    # asymmetric padding (3 on the left, 5 on the right of the last dim).
    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            # Fully dynamic 3-D float32 input.
            ([-1, -1, -1], torch.float32, True),
        ]
    )
    def forward(self, x):
        return torch.ops.aten.replication_pad1d(x, [3, 5])


@register_test_case(module_factory=lambda: ReplicationPad1dModule_3DInput())
def ReplicationPad1dModule_3DInput_basic(module, tu: TestUtils):
    # Random values in [-1, 1) so both signs are exercised.
    module.forward(tu.rand(1, 15, 20, low=-1))


# ==============================================================================


class ReplicationPad1dModule_2DInput(torch.nn.Module):
    # E2e test module: replication_pad1d on a 2-D (C, W) input with
    # asymmetric padding (2 on the left, 3 on the right of the last dim).
    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            # Fully dynamic 2-D float32 input.
            ([-1, -1], torch.float32, True),
        ]
    )
    def forward(self, x):
        return torch.ops.aten.replication_pad1d(x, [2, 3])


@register_test_case(module_factory=lambda: ReplicationPad1dModule_2DInput())
def ReplicationPad1dModule_2DInput_basic(module, tu: TestUtils):
    # Random values in [-1, 1) so both signs are exercised.
    module.forward(tu.rand(7, 12, low=-1))


# ==============================================================================


class ReflectionPad2dModule(torch.nn.Module):
def __init__(self):
super().__init__()
Expand Down
Loading