[torch] Add support for torch.split_with_sizes via decompose (llvm#2979)

Convert to individual slices and tuple them together as a list.
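In effect, each output becomes a slice [begin, begin + size) along the split dimension, and the slices are collected into a list. A minimal Python sketch of those semantics (the helper name here is illustrative, not part of the change):

import torch

def split_with_sizes_as_slices(x, split_sizes, dim=0):
    # Accumulate one slice per requested size; AtenSliceTensorOp in the
    # decomposition below models torch.ops.aten.slice.
    slices = []
    begin = 0
    for size in split_sizes:
        end = begin + size
        slices.append(torch.ops.aten.slice(x, dim, begin, end, 1))
        begin = end
    return slices

x = torch.rand(5, 2, 2)
expected = torch.split(x, [2, 1, 2], dim=0)
actual = split_with_sizes_as_slices(x, [2, 1, 2], dim=0)
assert all(torch.equal(a, b) for a, b in zip(actual, expected))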

---------

Co-authored-by: Scott Todd <scott.todd0@gmail.com>
rsuderman and ScottTodd authored Mar 5, 2024
1 parent 933db87 commit bc05276
Showing 5 changed files with 177 additions and 10 deletions.
24 changes: 24 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -12910,6 +12910,30 @@ def Torch_AtenSplitWithSizesOp : Torch_Op<"aten.split_with_sizes", [
}];
}

def Torch_AtenSplitSizesOp : Torch_Op<"aten.split.sizes", [
AllowsTypeRefinement,
ReadOnly
]> {
let summary = "Generated op for `aten::split.sizes : (Tensor, int[], int) -> (Tensor[])`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchListOfTorchIntType:$split_size,
Torch_IntType:$dim
);
let results = (outs
AnyTorchListOfTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenSplitSizesOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 3, 1);
}
void AtenSplitSizesOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 3, 1);
}
}];
}

def Torch_AtenUnbindIntOp : Torch_Op<"aten.unbind.int", [
AllowsTypeRefinement,
ReadOnly
128 changes: 128 additions & 0 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -693,6 +693,131 @@ class DecomposeAtenSelectIntOp : public OpRewritePattern<AtenSelectIntOp> {
};
} // namespace

namespace {
class DecomposePrimTolistOp : public OpRewritePattern<PrimTolistOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(PrimTolistOp op,
PatternRewriter &rewriter) const override {
auto loc = op.getLoc();
auto self = op.getOperands()[0];
auto selfTy = dyn_cast<Torch::BaseTensorType>(self.getType());
if (!selfTy || !selfTy.hasSizes())
return rewriter.notifyMatchFailure(op, "Unknown self shape");

int64_t rank = selfTy.getSizes().size();
if (rank != 1)
return rewriter.notifyMatchFailure(op, "Expected rank-1");

int64_t length = selfTy.getSizes().back();
if (length == Torch::kUnknownSize)
return rewriter.notifyMatchFailure(op, "Tolist length is unknown");

auto resultTy = dyn_cast<Torch::ListType>(op.getType(0));
if (!resultTy)
return rewriter.notifyMatchFailure(op, "Result type is not list");

auto scalarTy = resultTy.getContainedType();
Value zero =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
auto extractTy = rewriter.getType<ValueTensorType>(
llvm::SmallVector<int64_t>{1}, selfTy.getOptionalDtype());
llvm::SmallVector<Value> results;
for (int64_t i = 0; i < length; ++i) {
Value iv =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(i));
Value extract = rewriter.create<AtenSelectIntOp>(
loc, extractTy, self, /*dim=*/zero, /*index=*/iv);
Value scalar = rewriter.create<AtenItemOp>(loc, scalarTy, extract);
results.push_back(scalar);
}

rewriter.replaceOpWithNewOp<PrimListConstructOp>(op, resultTy, results);
return success();
}
};
} // namespace
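For a rank-1 tensor of statically known length, the pattern above amounts to the following sketch (a hedged Python illustration of the emitted ops, not code from this change):

import torch

def tolist_decomposed(x):
    # prim::tolist on a rank-1 tensor of known length N lowers to N
    # aten.select + aten.item pairs feeding a prim.ListConstruct.
    assert x.dim() == 1
    return [torch.ops.aten.select(x, 0, i).item() for i in range(x.size(0))]

assert tolist_decomposed(torch.tensor([1, 2, 3])) == [1, 2, 3]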

namespace {
class DecomposeAtenSplitSizesOp : public OpRewritePattern<AtenSplitSizesOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenSplitSizesOp op,
PatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<AtenSplitWithSizesOp>(
op, op->getResultTypes(), op.getSelf(), op.getSplitSize(), op.getDim());
return success();
}
};
} // namespace
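aten::split.sizes has the same signature as aten::split_with_sizes (see the ODS summaries above), so this rewrite only swaps the op name. At the Python level both are reachable from the same call; which overload a frontend records can vary:

import torch

x = torch.rand(5, 2)
# Passing a list of sizes routes torch.split to the split_with_sizes family.
parts = torch.split(x, [2, 3], dim=0)
assert [tuple(p.shape) for p in parts] == [(2, 2), (3, 2)]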

namespace {
class DecomposeAtenSplitWithSizesOp
: public OpRewritePattern<AtenSplitWithSizesOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenSplitWithSizesOp op,
PatternRewriter &rewriter) const override {
auto loc = op.getLoc();
Value self = op.getSelf();
SmallVector<Value> splitSizes;
if (!getListConstructElements(op.getSplitSizes(), splitSizes))
return rewriter.notifyMatchFailure(op, "Unable to get sizes");

if (splitSizes.empty())
return rewriter.notifyMatchFailure(op, "No split sizes");

auto selfTy = dyn_cast<BaseTensorType>(self.getType());
if (!selfTy || !selfTy.hasSizes())
return rewriter.notifyMatchFailure(op, "Self shape unknown");

int64_t rank = selfTy.getSizes().size();
auto resultTy = dyn_cast<Torch::ListType>(op.getResult().getType());
if (!resultTy)
return rewriter.notifyMatchFailure(op, "Result type not a list");

auto sliceTy =
dyn_cast_or_null<Torch::BaseTensorType>(resultTy.getContainedType());
if (!sliceTy)
return rewriter.notifyMatchFailure(op, "Slice type is unknown");

int64_t dimInt = 0;
bool hasDim = matchPattern(op.getDim(), m_TorchConstantInt(&dimInt));
if (dimInt < 0)
dimInt += rank;

auto intTy = rewriter.getType<Torch::IntType>();
Value one =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
Value begin =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));

llvm::SmallVector<Value> slices;
llvm::SmallVector<int64_t> sliceSizes(sliceTy.getSizes());
int64_t defaultLength = !hasDim ? Torch::kUnknownSize : sliceSizes[dimInt];
for (auto size : splitSizes) {
Value end = rewriter.create<AtenAddIntOp>(loc, intTy, begin, size);

int64_t sizeInt;
if (hasDim && matchPattern(size, m_TorchConstantInt(&sizeInt))) {
sliceSizes[dimInt] = sizeInt;
} else if (hasDim) {
sliceSizes[dimInt] = defaultLength;
}

sliceTy = rewriter.getType<ValueTensorType>(sliceSizes,
sliceTy.getOptionalDtype());
Value slice = rewriter.create<AtenSliceTensorOp>(
loc, sliceTy, op.getSelf(),
/*dim=*/op.getDim(), /*start=*/begin, /*end=*/end, /*step=*/one);
slices.push_back(slice);
begin = end;
}

rewriter.replaceOpWithNewOp<PrimListConstructOp>(op, resultTy, slices);
return success();
}
};
} // namespace
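The bookkeeping around sliceSizes refines each slice's static shape when the split dimension is a known constant. A hedged Python sketch of that rule (helper and names are illustrative; -1 stands in for Torch::kUnknownSize):

def refine_slice_shapes(slice_sizes, split_sizes, dim=None):
    # With a constant dim, a constant split size becomes the slice's
    # extent along dim; a non-constant size falls back to the extent the
    # declared slice type already carried (often -1, i.e. unknown).
    default = slice_sizes[dim] if dim is not None else None
    shapes = []
    for size in split_sizes:
        sizes = list(slice_sizes)
        if dim is not None:
            sizes[dim] = size if isinstance(size, int) else default
        shapes.append(sizes)
    return shapes

assert refine_slice_shapes([-1, 2, 2], [2, 1, 2], dim=0) == [
    [2, 2, 2], [1, 2, 2], [2, 2, 2]]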

namespace {
class DecomposeAtenNarrowOp : public OpRewritePattern<AtenNarrowOp> {
public:
@@ -7008,6 +7133,8 @@ class DecomposeComplexOpsPass
addPatternIfTargetOpIsIllegal<DecomposeAtenVarCorrectionOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenStdDimOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenStdCorrectionOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenSplitSizesOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenSplitWithSizesOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenNarrowOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenNarrowTensorOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenGluOp>(patterns);
@@ -7035,6 +7162,7 @@ class DecomposeComplexOpsPass
addPatternIfTargetOpIsIllegal<DecomposeAtenNewEmptyStridedOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenEmptyStridedOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenBucketizeTensorOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposePrimTolistOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposePrimsSqueezeOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenMovedimIntOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenOneHotOp>(patterns);
12 changes: 2 additions & 10 deletions projects/pt1/e2e_testing/xfail_sets.py
@@ -20,7 +20,8 @@
# 'linalg.depthwise_conv_2d_nchw_chw' op inferred input/output operand #1 has shape's dimension #0 to be 4, but found 8
"Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier",
"IscloseStaticModule_basic",
"IscloseStaticModuleTrue_basic"
"IscloseStaticModuleTrue_basic",
"SplitWithSizes_Module_basic",
}

TORCHDYNAMO_XFAIL_SET = {
@@ -1478,15 +1479,6 @@
"VarBiasedModule_basic",
"VarMeanBiasedModule_basic",

# Failure - constant int lowering
"SplitTensorGetItem_Module_basic",
"SplitTensorLastSmallerModule_basic",
"SplitTensorListUnpackModule_basic",
"SplitTensorNegativeDimModule_basic",
"SplitWithSizesListUnpackModule_basic",
"UnbindIntGetItem_Module_basic",
"UnbindIntListUnpack_Module_basic",

# Failure - incorrect numerics
"AdaptiveAvgPool1dUnitOutputSizeDynamicModule_basic",
"AdaptiveAvgPool2dUnitOutputSizeDynamicModule_basic",
1 change: 1 addition & 0 deletions projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
@@ -741,6 +741,7 @@ def emit_with_mutating_variants(key, **kwargs):
emit("aten::sort : (Tensor, int, bool) -> (Tensor, Tensor)", has_folder=True)
emit("aten::split.Tensor : (Tensor, int, int) -> (Tensor[])")
emit("aten::split_with_sizes : (Tensor, int[], int) -> (Tensor[])")
emit("aten::split.sizes : (Tensor, int[], int) -> (Tensor[])")
emit("aten::unbind.int : (Tensor, int) -> (Tensor[])")
emit("aten::chunk : (Tensor, int, int) -> (Tensor[])")

22 changes: 22 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/slice_like.py
@@ -897,3 +897,25 @@ def forward(self, x):
@register_test_case(module_factory=lambda: ChunkListUnpackUnevenDynamic_Module())
def ChunkListUnpackUnevenDynamic_Module_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 13, 2))

# ==============================================================================

class SplitWithSizes_Module(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([5, -1, -1], torch.float32, True),
])
def forward(self, x):
split = torch.split(x, [2, 1, 2], dim=0)
return split[0], split[1], split[2]

@register_test_case(module_factory=lambda: SplitWithSizes_Module())
def SplitWithSizes_Module_basic(module, tu: TestUtils):
module.forward(tu.rand(5, 2, 2))


