
aten::split converter #2232


Merged 7 commits on Sep 21, 2023
29 changes: 29 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -11,6 +11,7 @@
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor

from .converter_registry import dynamo_tensorrt_converter
from .converter_utils import dynamic_unsupported_with_args

_LOGGER: logging.Logger = logging.getLogger(__name__)

@@ -354,6 +355,34 @@ def aten_ops_softmax(
)


@dynamo_tensorrt_converter(
torch.ops.aten.split.Tensor, capability_validator=dynamic_unsupported_with_args([1])
[Collaborator review comment: The dynamic_unsupported_with_args import may also have been removed in the rebase.]

)
@dynamo_tensorrt_converter(
torch.ops.aten.split.sizes, capability_validator=dynamic_unsupported_with_args([1])
)
@dynamo_tensorrt_converter(
torch.ops.aten.split_with_sizes.default,
capability_validator=dynamic_unsupported_with_args([1]),
)
def aten_ops_split(
network: TRTNetwork,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.split.split(
network,
target,
SourceIR.ATEN,
name,
input=args[0],
split_size_or_sections=args[1],
dim=args_bounds_check(args, 2, 0),
)


@dynamo_tensorrt_converter(torch.ops.aten.where.self) # type: ignore[misc]
def aten_ops_where(
network: TRTNetwork,
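For context, a minimal sketch of how the new registration could be exercised end to end through the Dynamo frontend. The torch_tensorrt.compile call and its options are assumptions about the usual entry point, not part of this diff:

import torch
import torch_tensorrt

class SplitModule(torch.nn.Module):
    def forward(self, x):
        # Traces to torch.ops.aten.split.Tensor, which aten_ops_split now handles
        return torch.split(x, 2, dim=1)

model = SplitModule().eval().cuda()
inputs = [torch.randn(1, 10, device="cuda")]

# Hypothetical invocation; exact settings may differ by release
trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs, min_block_size=1)
print([t.shape for t in trt_model(*inputs)])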
52 changes: 36 additions & 16 deletions py/torch_tensorrt/dynamo/conversion/converter_utils.py
@@ -1,11 +1,12 @@
import functools
import logging
import re
-from typing import Any, List, Optional, Tuple, Union
+from typing import Any, Callable, List, Optional, Tuple, Union

import numpy as np
import tensorrt as trt
import torch
from torch import SymBool, SymFloat, SymInt
from torch.fx.node import Target
from torch_tensorrt.fx.converters.converter_utils import (
Frameworks,
@@ -60,34 +61,53 @@ def is_only_operator_on_placeholder(node: torch.fx.Node) -> bool:


def dynamic_unsupported(node: torch.fx.Node) -> bool:
"""Validates that a node has no dynamic args, kwargs, or outputs"""
return _dynamic_unsupported(node=node)


def dynamic_unsupported_with_args(
arg_positions_to_check: Optional[List[int]] = None,
) -> Callable[[torch.fx.Node], bool]:
"""Returns a validator that a node has no dynamic args at specific positions"""
return functools.partial(
_dynamic_unsupported, arg_positions_to_check=arg_positions_to_check
)


def _dynamic_unsupported(
node: torch.fx.Node, arg_positions_to_check: Optional[List[int]] = None
) -> bool:
# Validate that none of the inputs to the node have Dynamic shapes
assert isinstance(
node, torch.fx.Node
), "Inputs to validator functions must be FX Nodes"

def _is_subnode_dynamic(subnode: torch.fx.Node) -> bool:
"""Checks if a node itself has Dynamic properties"""
return getattr(
subnode.meta["val"], "_has_symbolic_sizes_strides", False
) or isinstance(subnode.meta["val"], (SymFloat, SymInt, SymBool))

# Check node value itself
if ("val" in node.meta) and getattr(
node.meta["val"], "_has_symbolic_sizes_strides", False
):
if arg_positions_to_check is None and _is_subnode_dynamic(node):
return False

# Check node arguments individually
-    if any(
-        (
-            ("val" in arg.meta)
-            and getattr(arg.meta["val"], "_has_symbolic_sizes_strides", False)
-        )
-        for arg in node.args
-        if isinstance(arg, torch.fx.Node)
+    if arg_positions_to_check is None and any(
+        _is_subnode_dynamic(arg) for arg in node.args if isinstance(arg, torch.fx.Node)
):
return False
# Check specific arg positions if the caller has specified positions to check
elif arg_positions_to_check is not None and any(
_is_subnode_dynamic(node.args[i])
for i in arg_positions_to_check
if isinstance(node.args[i], torch.fx.Node)
):
return False

# Check node keyword arguments individually
-    if any(
-        (
-            ("val" in kwarg.meta)
-            and getattr(kwarg.meta["val"], "_has_symbolic_sizes_strides", False)
-        )
+    if arg_positions_to_check is None and any(
+        _is_subnode_dynamic(kwarg)
for kwarg in node.kwargs.values()
if isinstance(kwarg, torch.fx.Node)
):
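To illustrate what the new validator does, here is a small self-contained sketch. The toy FX graph and the hand-filled meta["val"] entries are illustrative assumptions (in practice the Dynamo tracer populates them); the point is that with arg_positions_to_check=[1], only the split-size argument is inspected for dynamic shapes:

import torch
from torch_tensorrt.dynamo.conversion.converter_utils import dynamic_unsupported_with_args

# Build a toy graph node for aten.split.Tensor(x, 2)
graph = torch.fx.Graph()
x = graph.placeholder("x")
x.meta["val"] = torch.empty(4, 6)  # static shape; ignored, since position 0 is not checked
split_node = graph.call_function(torch.ops.aten.split.Tensor, args=(x, 2))
split_node.meta["val"] = [torch.empty(2, 6), torch.empty(2, 6)]

validator = dynamic_unsupported_with_args([1])
# Expected True: args[1] is a plain Python int, not a torch.fx.Node carrying a SymInt
print(validator(split_node))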
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/__init__.py
@@ -15,6 +15,7 @@
select,
shape,
slice,
split,
squeeze,
unary,
unsqueeze,
81 changes: 81 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/split.py
@@ -0,0 +1,81 @@
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

import numpy as np
import torch
import torch_tensorrt as trt
from torch import Tensor
from torch.fx.node import Target
from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.dynamo.conversion.impl.shape import get_shape_with_dynamic_shape
from torch_tensorrt.fx.converters.converter_utils import (
has_dynamic_shape,
set_layer_name,
)
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor


def split(
network: TRTNetwork,
target: Target,
source_ir: Optional[SourceIR],
name: str,
input: TRTTensor,
split_size_or_sections: Union[int, List[int]],
dim: int = 0,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
if not isinstance(input, TRTTensor):
raise RuntimeError(
f"split received input {input} that is not part " "of the TensorRT region!"
)

dynamic_shape = has_dynamic_shape(input.shape)
if dynamic_shape > 0:
# Check whether slice target dim is dynamic shape dim
assert input.shape[dim] != -1, "Can't chunk on dynamic shape dimension!"

split_sizes = []
if isinstance(split_size_or_sections, int):
split_sizes.append(split_size_or_sections)
else:
for split_size_or_section in split_size_or_sections:
split_sizes.append(split_size_or_section)

start = [0] * len(input.shape)
stride = [1] * len(start)
offset = 0
if len(split_sizes) == 1:
num_splits = (input.shape[dim] + split_sizes[0] - 1) // split_sizes[0]
split_sizes = [split_sizes[0]] * num_splits
else:
num_splits = len(split_sizes)
sum_split_sizes = sum(split_sizes)
if sum_split_sizes != input.shape[dim]:
raise RuntimeError(
f"split sizes don't add up to the tensor's size in the given dimension"
)

if num_splits < 1:
raise RuntimeError(
f"Invalid split: {input.shape[dim]} with split_size={split_sizes}"
)

max_offset = input.shape[dim]
# add slice layers
output = []
for i in range(num_splits):
shape = list(input.shape)
shape[dim] = min(split_sizes[i], max_offset - offset)
start[dim] = offset
if dynamic_shape:
shape = get_shape_with_dynamic_shape(
network, target, source_ir, f"{name}_shape_{i}", shape, input
)
layer = network.add_slice(
input, start=start, shape=[] if dynamic_shape else shape, stride=stride
)
if dynamic_shape:
layer.set_input(2, shape)
offset += split_sizes[i]
set_layer_name(layer, target, f"{name}_{i}")
output.append(layer.get_output(0))
return output
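To make the slice arithmetic above concrete, a small plain-Python sketch (no TensorRT required) that mirrors how the converter derives each slice layer's start and length for a single split size:

def plan_slices(dim_size: int, split_size: int):
    # Same arithmetic as impl.split: ceil-divide to get the number of chunks,
    # then clamp the last chunk to whatever remains along the dimension.
    num_splits = (dim_size + split_size - 1) // split_size
    offset, slices = 0, []
    for _ in range(num_splits):
        slices.append((offset, min(split_size, dim_size - offset)))
        offset += split_size
    return slices

# torch.split along a dimension of size 5 with split_size=2 -> chunks of 2, 2, 1
print(plan_slices(5, 2))  # [(0, 2), (2, 2), (4, 1)]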
175 changes: 175 additions & 0 deletions tests/py/dynamo/conversion/test_split_aten.py
@@ -0,0 +1,175 @@
import torch
from .harness import DispatchTestCase
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests
from torch_tensorrt import Input
from torch_tensorrt.dynamo.conversion import UnsupportedOperatorException


# FIXME: check about implicit and explicit batch
class TestSplitConverterNoDim(DispatchTestCase):
@parameterized.expand(
[
("split_size_or_sections_no_dim", 2),
]
)
def test_split(self, _, split_size_or_tensor):
class TestModule(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, input):
out = torch.split(input, split_size_or_tensor)
return out

input = [torch.randn(10).reshape(5, 2)]
self.run_test(
TestModule(),
input,
expected_ops={torch.ops.aten.split.Tensor},
disable_passes=True,
)

@parameterized.expand(
[
("split_size_or_sections_list_no_dim_list", [1, 4]),
]
)
def test_split_list(self, _, split_size_or_tensor):
class TestModule(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, input):
out = torch.split(input, split_size_or_tensor)
return out

input = [torch.randn(10).reshape(5, 2)]
self.run_test(
TestModule(),
input,
expected_ops={torch.ops.aten.split_with_sizes.default},
disable_passes=True,
)

@parameterized.expand(
[
("split_size_or_sections_dims", 2, 1),
]
)
    def test_split_dim(self, _, split_size_or_tensor, dim):
class TestModule(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, input):
out = torch.split(input, split_size_or_tensor, dim)
return out

input = [torch.randn(10).reshape(5, 2)]
self.run_test(
TestModule(),
input,
expected_ops={torch.ops.aten.split.Tensor},
disable_passes=True,
)

@parameterized.expand(
[
("split_size_or_sections_list_dims", [1, 1], 1),
]
)
def test_split_dim_list(self, _, split_size_or_tensor, dim):
class TestModule(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, input):
out = torch.split(input, split_size_or_tensor, dim)
return out

input = [torch.randn(10).reshape(5, 2)]
self.run_test(
TestModule(),
input,
expected_ops={torch.ops.aten.split_with_sizes.default},
disable_passes=True,
)

@parameterized.expand(
[
("split_size_or_sections_list_dims_not_full_list", [1, 1], 1),
]
)
    def test_split_dim_list_not_full(self, _, split_size_or_tensor, dim):
class TestModule(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, input):
out = torch.split(input, split_size_or_tensor, dim)
return out

input = [torch.randn(15).reshape(5, 3)]
with self.assertRaises(RuntimeError):
self.run_test(
TestModule(),
input,
expected_ops={torch.ops.aten.split_with_sizes.default},
disable_passes=True,
)

@parameterized.expand(
[
("select_split_size_or_sections_dim_dynamic_shape", 2, 1),
]
)
def test_split_dynamic(self, _, split_size_or_tensor, dim):
class TestModule(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, input):
out = torch.split(input, split_size_or_tensor, dim)
return out

input_specs = [
Input(
shape=(1, 10, -1),
dtype=torch.float32,
shape_ranges=[((1, 10, 1), (1, 10, 10), (1, 10, 10))],
),
]
self.run_test_with_dynamic_shape(
TestModule(),
input_specs,
expected_ops={torch.ops.aten.split.Tensor},
disable_passes=True,
)

@parameterized.expand(
[
("select_chunk_dim", 6, 0),
]
)
    def test_chunk_unsupported(self, _, chunk, dim):
class TestModule(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, input):
out = torch.ops.aten.chunk(input, chunk, dim)
return out

input = [torch.randn(11)]
with self.assertRaises(UnsupportedOperatorException):
self.run_test(
TestModule(),
input,
expected_ops={torch.ops.aten.split.Tensor},
disable_passes=True,
)


if __name__ == "__main__":
run_tests()