Implement lowering of torch.aten.hstack #3563

Merged · 1 commit · Sep 11, 2024
23 changes: 23 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -13998,6 +13998,29 @@ def Torch_AtenStackOp : Torch_Op<"aten.stack", [
}];
}

def Torch_AtenHstackOp : Torch_Op<"aten.hstack", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::hstack : (Tensor[]) -> (Tensor)`";
let arguments = (ins
AnyTorchListOfTensorType:$tensors
);
let results = (outs
AnyTorchOptionalTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenHstackOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 1, 1);
}
void AtenHstackOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 1, 1);
}
}];
}

def Torch_AtenAppendTOp : Torch_Op<"aten.append.t", [
AllowsTypeRefinement
]> {
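The new ODS definition registers `aten::hstack` with the Torch dialect. Its semantics follow `torch.hstack`: inputs are first promoted to at least rank 1, then concatenated along dim 0 if they are 1-D and along dim 1 otherwise. A minimal eager-mode sketch of that behavior (illustrative only, not part of this change):

import torch

# 1-D inputs are concatenated along dim 0.
a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5, 6])
assert torch.hstack([a, b]).shape == (6,)

# Inputs of rank >= 2 are concatenated along dim 1.
x = torch.rand(2, 3, 4)
y = torch.rand(2, 5, 4)
assert torch.hstack([x, y]).shape == (2, 8, 4)

# 0-d inputs are promoted to rank 1 (atleast_1d) before concatenation.
assert torch.hstack([torch.tensor(1.0), torch.tensor(2.0)]).shape == (2,)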
52 changes: 52 additions & 0 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -10572,6 +10572,31 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = call @__torch__.torch.jit._shape_functions.stack(%arg0, %arg1) : (!torch.list<list<int>>, !torch.int) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.hstack\"(%arg0: !torch.list<list<int>>) -> !torch.list<int> {\n"
" %true = torch.constant.bool true\n"
" %int0 = torch.constant.int 0\n"
" %int1 = torch.constant.int 1\n"
" %0 = torch.prim.ListConstruct : () -> !torch.list<list<int>>\n"
" %1 = torch.aten.len.t %arg0 : !torch.list<list<int>> -> !torch.int\n"
" torch.prim.Loop %1, %true, init() {\n"
" ^bb0(%arg1: !torch.int):\n"
" %6 = torch.aten.__getitem__.t %arg0, %arg1 : !torch.list<list<int>>, !torch.int -> !torch.list<int>\n"
" %7 = func.call @\"__torch_mlir_shape_fn.aten.atleast_1d\"(%6) : (!torch.list<int>) -> !torch.list<int>\n"
" %8 = torch.aten.append.t %0, %7 : !torch.list<list<int>>, !torch.list<int> -> !torch.list<list<int>>\n"
" torch.prim.Loop.condition %true, iter()\n"
" } : (!torch.int, !torch.bool) -> ()\n"
" %2 = torch.aten.__getitem__.t %0, %int0 : !torch.list<list<int>>, !torch.int -> !torch.list<int>\n"
" %3 = torch.aten.len.t %2 : !torch.list<int> -> !torch.int\n"
" %4 = torch.aten.eq.int %3, %int1 : !torch.int, !torch.int -> !torch.bool\n"
" %5 = torch.prim.If %4 -> (!torch.list<int>) {\n"
" %6 = func.call @__torch__.torch.jit._shape_functions.cat(%0, %int0) : (!torch.list<list<int>>, !torch.int) -> !torch.list<int>\n"
" torch.prim.If.yield %6 : !torch.list<int>\n"
" } else {\n"
" %6 = func.call @__torch__.torch.jit._shape_functions.cat(%0, %int1) : (!torch.list<list<int>>, !torch.int) -> !torch.list<int>\n"
" torch.prim.If.yield %6 : !torch.list<int>\n"
" }\n"
" return %5 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.fft_fft\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.int, %arg3: !torch.optional<str>) -> !torch.list<int> {\n"
" return %arg0 : !torch.list<int>\n"
" }\n"
@@ -15070,6 +15095,33 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.hstack\"(%arg0: !torch.list<tuple<int, int>>) -> !torch.int {\n"
" %true = torch.constant.bool true\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: \"\n"
" %int0 = torch.constant.int 0\n"
" %0 = torch.prim.ListConstruct : () -> !torch.list<optional<int>>\n"
" %1 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
" %2 = torch.aten.len.t %arg0 : !torch.list<tuple<int, int>> -> !torch.int\n"
" %3 = torch.aten.ne.int %2, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %3 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %4 = torch.aten.len.t %arg0 : !torch.list<tuple<int, int>> -> !torch.int\n"
" torch.prim.Loop %4, %true, init() {\n"
" ^bb0(%arg1: !torch.int):\n"
" %6 = torch.aten.__getitem__.t %arg0, %arg1 : !torch.list<tuple<int, int>>, !torch.int -> !torch.tuple<int, int>\n"
" %7:2 = torch.prim.TupleUnpack %6 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %8 = torch.aten.append.t %0, %7#0 : !torch.list<optional<int>>, !torch.int -> !torch.list<optional<int>>\n"
" %9 = torch.aten.append.t %1, %7#1 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
" torch.prim.Loop.condition %true, iter()\n"
" } : (!torch.int, !torch.bool) -> ()\n"
" %5 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%0, %1) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" return %5 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.einsum\"(%arg0: !torch.str, %arg1: !torch.list<tuple<int, int>>, %arg2: !torch.optional<list<int>>) -> !torch.int {\n"
" %true = torch.constant.bool true\n"
" %none = torch.constant.none\n"
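The generated `__torch_mlir_shape_fn.aten.hstack` above applies the `atleast_1d` shape function to every input shape and then delegates to the upstream `cat` shape function, using dim 0 when the promoted shapes are 1-D and dim 1 otherwise. A plain-Python sketch of the same rule, assuming the inputs already agree on every non-concatenated dimension (the real library validates this inside the upstream `cat` helper):

from typing import List

def hstack_shape(shapes: List[List[int]]) -> List[int]:
    # Mirror atleast_1d on shapes: a rank-0 shape becomes [1].
    promoted = [s if len(s) >= 1 else [1] for s in shapes]
    # 1-D inputs concatenate along dim 0, everything else along dim 1.
    dim = 0 if len(promoted[0]) == 1 else 1
    out = list(promoted[0])
    out[dim] = sum(s[dim] for s in promoted)
    return out

assert hstack_shape([[2, 4, 3], [2, 5, 3]]) == [2, 9, 3]
assert hstack_shape([[4], [6]]) == [10]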
53 changes: 53 additions & 0 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -3779,6 +3779,58 @@ class DecomposeAtenStackOp : public OpRewritePattern<AtenStackOp> {
};
} // namespace

// Decompose `aten.hstack` into `aten.atleast_1d` and `aten.cat`.
// https://github.com/pytorch/pytorch/blob/207564bab1c4fe42750931765734ee604032fb69/torch/_refs/__init__.py#L3908
namespace {
class DecomposeAtenHstackOp : public OpRewritePattern<AtenHstackOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenHstackOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();

// Collect the input tensors from the list-construct operand.
SmallVector<Value> tensors;
if (!getListConstructElements(op.getTensors(), tensors))
return rewriter.notifyMatchFailure(
op, "unimplemented: the tensor list is not from list construct");

// Apply AtenAtleast1dOp to each input tensor that needs rank promotion.
SmallVector<Value> atleast1dTensors;
for (auto tensor : tensors) {
std::optional<unsigned> tensorRank = getTensorRank(tensor);

// Promote rank-0 tensors to rank 1; tensors of rank >= 1 pass through unchanged.
if (*tensorRank < 1) {
auto atleast1dTensor =
rewriter.create<AtenAtleast1dOp>(loc, tensor.getType(), tensor);
atleast1dTensors.push_back(atleast1dTensor);
} else {
atleast1dTensors.push_back(tensor);
}
}

// Pack the (possibly promoted) tensors into a tensor-list value.
auto elemType = cast<BaseTensorType>(atleast1dTensors[0].getType())
.getWithSizesAndDtype(std::nullopt, nullptr);
Value atleast1dTensorList = rewriter.create<PrimListConstructOp>(
loc, Torch::ListType::get(elemType), atleast1dTensors);

// Replace hstack with cat along dim 0 for 1-D inputs, dim 1 otherwise.
if (getTensorRank(atleast1dTensors[0]) == 1)
rewriter.replaceOpWithNewOp<AtenCatOp>(
op, op.getType(), atleast1dTensorList,
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0)));
else
rewriter.replaceOpWithNewOp<AtenCatOp>(
op, op.getType(), atleast1dTensorList,
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1)));

return success();
}
};
} // namespace

// Decompose aten.roll into aten.slice and aten.cat ops.
// https://pytorch.org/docs/stable/generated/torch.roll.html
namespace {
@@ -9386,6 +9438,7 @@ class DecomposeComplexOpsPass
addPatternIfTargetOpIsIllegal<
DecomposeConstantTensorAllocLikeOp<AtenZerosLikeOp, 0>>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenStackOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenHstackOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenRollOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenRepeatOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenRepeatInterleaveSelfIntOp>(
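`DecomposeAtenHstackOp` rewrites `aten.hstack` into `aten.atleast_1d` on each operand followed by a single `aten.cat`, choosing dim 0 when the promoted operands are 1-D and dim 1 otherwise. A quick eager-mode check (illustrative, not part of the patch) that this decomposition agrees with `torch.hstack`:

import torch

def hstack_decomposed(tensors):
    # Mirror the decomposition: atleast_1d on every operand, then a single cat.
    promoted = [torch.atleast_1d(t) for t in tensors]
    dim = 0 if promoted[0].dim() == 1 else 1
    return torch.cat(promoted, dim=dim)

cases = [
    [torch.rand(2, 3, 4), torch.rand(2, 5, 4)],       # rank-3 inputs
    [torch.tensor([1.0, 2.0]), torch.tensor([3.0])],  # 1-D inputs
    [torch.tensor(1.0), torch.tensor(2.0)],           # 0-d inputs
]
for tensors in cases:
    assert torch.equal(hstack_decomposed(tensors), torch.hstack(tensors))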
1 change: 1 addition & 0 deletions lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
@@ -380,6 +380,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
target.addIllegalOp<AtenOnesLikeOp>();
target.addIllegalOp<AtenZerosLikeOp>();
target.addIllegalOp<AtenStackOp>();
target.addIllegalOp<AtenHstackOp>();
target.addIllegalOp<AtenRollOp>();
target.addIllegalOp<AtenRepeatOp>();
target.addIllegalOp<AtenRepeatInterleaveSelfIntOp>();
13 changes: 13 additions & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
@@ -1148,6 +1148,10 @@
"GridSamplerBasic4_basic",
"GtFloatIntModule_basic",
"GtIntModule_basic",
"HstackBasicComplexModule_basic",
"HstackBasicFloatModule_basic",
"HstackBasicIntFloatModule_basic",
"HstackBasicIntModule_basic",
"IndexTensorMultiIndexStaticModule_basic",
"IndexTensorStaticModule_basic",
"IntFloatModule_basic",
@@ -2134,6 +2138,11 @@
# failed to legalize operation 'torch.aten.rrelu_with_noise'
"ElementwiseRreluEvalModule_basic",
"ElementwiseRreluEvalStaticModule_basic",
# incompatible return type failure for tosa.concat.
"HstackBasicComplexModule_basic",
"HstackBasicFloatModule_basic",
"HstackBasicIntFloatModule_basic",
"HstackBasicIntModule_basic",
# Shape Related failures
"PrimListUnpackNumMismatchModule_basic",
"ReshapeExpandModule_basic",
@@ -2527,6 +2536,10 @@
"GtFloatIntModule_basic",
"GtIntModule_basic",
"HardtanhBackward_basic",
"HstackBasicComplexModule_basic",
"HstackBasicFloatModule_basic",
"HstackBasicIntFloatModule_basic",
"HstackBasicIntModule_basic",
"IndexPutImpl1DFloatAccumulateModule_basic",
"IndexPutImpl1DFloatNonAccumulateModule_basic",
"IndexPutImpl1DIntAccumulateModule_basic",
@@ -2133,6 +2133,19 @@ def aten〇atleast_2d〡shape(self: List[int]) -> List[int]:
def aten〇stack〡shape(tensors: List[List[int]], dim: int = 0) -> List[int]:
return upstream_shape_functions.stack(tensors, dim)


@check_shape_function([
Invocation([LongTensorOfShape(2, 4, 3), LongTensorOfShape(2, 5, 3)]), # Basic case.
])
def aten〇hstack〡shape(tensors: List[List[int]]) -> List[int]:

tensors_atleast1d = [aten〇atleast_1d〡shape(tensor) for tensor in tensors]

if len(tensors_atleast1d[0]) == 1:
return upstream_shape_functions.cat(tensors_atleast1d, dim=0)

return upstream_shape_functions.cat(tensors_atleast1d, dim=1)

def aten〇fft_fft〡shape(self: List[int], n: Optional[int] = None, dim: int = -1, norm: Optional[str] = None) -> List[int]:
return self

@@ -5279,6 +5292,23 @@ def aten〇atleast_2d〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
self_rank, self_dtype = self_rank_dtype
return self_dtype

@check_dtype_function(
[Invocation([NonZeroDTensorWithDtype(torch.bool), NonZeroDTensorWithDtype(torch.int32), NonZeroDTensorWithDtype(torch.int64)]),
Invocation([NonZeroDTensorWithDtype(torch.float32), NonZeroDTensorWithDtype(torch.int32)]),
Invocation([NonZeroDTensorWithDtype(torch.float16), NonZeroDTensorWithDtype(torch.float64)]),
Invocation([NonZeroDTensorWithDtype(torch.float32), NonZeroDTensorWithDtype(torch.int32),
NonZeroDTensorWithDtype(torch.complex64)])])
def aten〇hstack〡dtype(tensors_rank_dtype: List[Tuple[int, int]]) -> int:
ranks: List[Optional[int]] = []
dtypes: List[int] = []
assert len(tensors_rank_dtype) != 0
for tensor_rank_dtype in tensors_rank_dtype:
tensor_rank, tensor_dtype = tensor_rank_dtype
ranks.append(tensor_rank)
dtypes.append(tensor_dtype)

return promote_dtypes(ranks, dtypes)

@check_dtype_function(
[Invocation("i,j->ij", [TensorOfShape(1, dtype=torch.float32),
TensorOfShape(1, dtype=torch.int32)]),])
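`aten〇hstack〡dtype` promotes the dtypes of all inputs via `promote_dtypes`, which follows PyTorch's standard type-promotion rules. A small eager-mode check of the promotions exercised by the invocations above (a sketch that uses `torch.hstack` itself as the promotion oracle):

import torch

cases = [
    ([torch.bool, torch.int32, torch.int64], torch.int64),
    ([torch.float32, torch.int32], torch.float32),
    ([torch.float16, torch.float64], torch.float64),
    ([torch.float32, torch.int32, torch.complex64], torch.complex64),
]
for dtypes, expected in cases:
    # Rank >= 1 inputs, matching the NonZeroDTensorWithDtype invocations.
    tensors = [torch.ones(2, dtype=d) for d in dtypes]
    assert torch.hstack(tensors).dtype == expected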
@@ -1008,6 +1008,7 @@ def emit_with_mutating_variants(key, **kwargs):
has_folder=True,
)
emit("aten::stack : (Tensor[], int) -> (Tensor)")
emit("aten::hstack : (Tensor[]) -> (Tensor)")
emit("aten::append.t : (t[], t) -> (t[])")
emit("aten::add.t : (t[], t[]) -> (t[])", has_canonicalizer=True)
emit("aten::eq.int_list : (int[], int[]) -> (bool)", has_folder=True)
101 changes: 101 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
@@ -1308,6 +1308,107 @@ def TensorsStackPromoteDTypeModule_basic(module, tu: TestUtils):
# ==============================================================================


class HstackBasicIntModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args(
[
None,
([2, 3, 4], torch.bool, True),
([2, 3, 4], torch.int32, True),
([2, 3, 4], torch.int64, True),
]
)
def forward(self, x, y, z):
return torch.ops.aten.hstack([x, y, z])


@register_test_case(module_factory=lambda: HstackBasicIntModule())
def HstackBasicIntModule_basic(module, tu: TestUtils):
module.forward(
tu.randint(2, 3, 4, low=0, high=2).bool(),
tu.randint(2, 3, 4, low=0, high=100).int(),
tu.randint(2, 3, 4, low=0, high=100).long(),
)


class HstackBasicFloatModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args(
[
None,
([2, 6, 4], torch.int32, True),
([2, 3, 4], torch.float64, True),
]
)
def forward(self, x, y):
return torch.ops.aten.hstack([x, y])


@register_test_case(module_factory=lambda: HstackBasicFloatModule())
def HstackBasicFloatModule_basic(module, tu: TestUtils):
module.forward(
tu.rand(2, 6, 4).int(),
tu.rand(2, 3, 4).double(),
)


class HstackBasicIntFloatModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args(
[
None,
([-1, -1, -1, -1], torch.int32, True),
([-1, -1, -1, -1], torch.float32, True),
]
)
def forward(self, x, y):
return torch.ops.aten.hstack([x, y])


@register_test_case(module_factory=lambda: HstackBasicIntFloatModule())
def HstackBasicIntFloatModule_basic(module, tu: TestUtils):
module.forward(
tu.randint(4, 6, 4, 2, low=1, high=50).int(),
tu.rand(4, 3, 4, 2),
)


class HstackBasicComplexModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args(
[
None,
([-1, -1, -1, -1], torch.complex64, True),
([-1, -1, -1, -1], torch.complex128, True),
]
)
def forward(self, x, y):
return torch.ops.aten.hstack([x, y])


@register_test_case(module_factory=lambda: HstackBasicComplexModule())
def HstackBasicComplexModule_basic(module, tu: TestUtils):
module.forward(
tu.rand(4, 6, 4, 2).type(torch.complex64),
tu.rand(4, 3, 4, 2).type(torch.complex128),
)


# ==============================================================================


class GatherModule(torch.nn.Module):
def __init__(self):
super().__init__()
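The four new e2e modules cover boolean/integer, mixed int/float, and complex dtype promotion, concatenating along dim 1 with mismatched widths. An eager-mode sanity check of the output shapes and dtypes those tests imply (illustrative only; the real tests run through the torch-mlir e2e harness):

import torch

# HstackBasicFloatModule: (2, 6, 4) int32 and (2, 3, 4) float64 -> (2, 9, 4) float64.
out = torch.hstack([torch.ones(2, 6, 4, dtype=torch.int32),
                    torch.ones(2, 3, 4, dtype=torch.float64)])
assert out.shape == (2, 9, 4) and out.dtype == torch.float64

# HstackBasicComplexModule: complex64 and complex128 promote to complex128.
out = torch.hstack([torch.rand(4, 6, 4, 2).to(torch.complex64),
                    torch.rand(4, 3, 4, 2).to(torch.complex128)])
assert out.shape == (4, 9, 4, 2) and out.dtype == torch.complex128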