-
Notifications
You must be signed in to change notification settings - Fork 364
aten::split converter #2232
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
aten::split converter #2232
Changes from all commits
Commits
Show all changes
7 commits
Select commit
Hold shift + click to select a range
6cf1d67
aten::split converter
apbose 3f4a2e8
Addressing review comments and checking in tests
apbose e5b7120
feat/fix: Update dynamic unsupported implementation
gs-olive 71bd68f
combining split tests
apbose 4fed932
Removing cast
apbose 797d510
Change in test_split location
apbose e95e0e4
Removing incorrect test and removing cast from split impl
apbose File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,6 +15,7 @@ | |
select, | ||
shape, | ||
slice, | ||
split, | ||
squeeze, | ||
unary, | ||
unsqueeze, | ||
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union | ||
|
||
import numpy as np | ||
import torch | ||
import torch_tensorrt as trt | ||
from torch import Tensor | ||
from torch.fx.node import Target | ||
from torch_tensorrt.dynamo._SourceIR import SourceIR | ||
from torch_tensorrt.dynamo.conversion.impl.shape import get_shape_with_dynamic_shape | ||
from torch_tensorrt.fx.converters.converter_utils import ( | ||
has_dynamic_shape, | ||
set_layer_name, | ||
) | ||
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor | ||
|
||
|
||
def split( | ||
network: TRTNetwork, | ||
target: Target, | ||
source_ir: Optional[SourceIR], | ||
name: str, | ||
input: TRTTensor, | ||
split_size_or_sections: Union[int, List[int]], | ||
dim: int = 0, | ||
) -> Union[TRTTensor, Sequence[TRTTensor]]: | ||
if not isinstance(input, TRTTensor): | ||
raise RuntimeError( | ||
f"split received input {input} that is not part " "of the TensorRT region!" | ||
) | ||
|
||
dynamic_shape = has_dynamic_shape(input.shape) | ||
if dynamic_shape > 0: | ||
# Check whether slice target dim is dynamic shape dim | ||
assert input.shape[dim] != -1, "Can't chunk on dynamic shape dimension!" | ||
|
||
split_sizes = [] | ||
if isinstance(split_size_or_sections, int): | ||
split_sizes.append(split_size_or_sections) | ||
else: | ||
for split_size_or_section in split_size_or_sections: | ||
split_sizes.append(split_size_or_section) | ||
|
||
start = [0] * len(input.shape) | ||
stride = [1] * len(start) | ||
offset = 0 | ||
if len(split_sizes) == 1: | ||
num_splits = (input.shape[dim] + split_sizes[0] - 1) // split_sizes[0] | ||
split_sizes = [split_sizes[0]] * num_splits | ||
else: | ||
num_splits = len(split_sizes) | ||
sum_split_sizes = sum(split_sizes) | ||
if sum_split_sizes != input.shape[dim]: | ||
raise RuntimeError( | ||
f"split sizes don't add up to the tensor's size in the given dimension" | ||
) | ||
|
||
if num_splits < 1: | ||
raise RuntimeError( | ||
f"Invalid split: {input.shape[dim]} with split_size={split_sizes}" | ||
) | ||
|
||
max_offset = input.shape[dim] | ||
# add slice layers | ||
output = [] | ||
for i in range(num_splits): | ||
shape = list(input.shape) | ||
shape[dim] = min(split_sizes[i], max_offset - offset) | ||
start[dim] = offset | ||
if dynamic_shape: | ||
shape = get_shape_with_dynamic_shape( | ||
network, target, source_ir, f"{name}_shape_{i}", shape, input | ||
) | ||
layer = network.add_slice( | ||
input, start=start, shape=[] if dynamic_shape else shape, stride=stride | ||
) | ||
if dynamic_shape: | ||
layer.set_input(2, shape) | ||
offset += split_sizes[i] | ||
set_layer_name(layer, target, f"{name}_{i}") | ||
output.append(layer.get_output(0)) | ||
return output |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,175 @@ | ||
import torch | ||
from .harness import DispatchTestCase | ||
from parameterized import parameterized | ||
from torch.testing._internal.common_utils import run_tests | ||
from torch_tensorrt import Input | ||
from torch_tensorrt.dynamo.conversion import UnsupportedOperatorException | ||
|
||
|
||
# FIXME: check about implicit and explicit batch | ||
class TestSplitConverterNoDim(DispatchTestCase): | ||
@parameterized.expand( | ||
[ | ||
("split_size_or_sections_no_dim", 2), | ||
] | ||
) | ||
def test_split(self, _, split_size_or_tensor): | ||
class TestModule(torch.nn.Module): | ||
def __init__(self): | ||
super().__init__() | ||
|
||
def forward(self, input): | ||
out = torch.split(input, split_size_or_tensor) | ||
return out | ||
|
||
input = [torch.randn(10).reshape(5, 2)] | ||
self.run_test( | ||
TestModule(), | ||
input, | ||
expected_ops={torch.ops.aten.split.Tensor}, | ||
disable_passes=True, | ||
) | ||
|
||
@parameterized.expand( | ||
[ | ||
("split_size_or_sections_list_no_dim_list", [1, 4]), | ||
] | ||
) | ||
def test_split_list(self, _, split_size_or_tensor): | ||
class TestModule(torch.nn.Module): | ||
def __init__(self): | ||
super().__init__() | ||
|
||
def forward(self, input): | ||
out = torch.split(input, split_size_or_tensor) | ||
return out | ||
|
||
input = [torch.randn(10).reshape(5, 2)] | ||
self.run_test( | ||
TestModule(), | ||
input, | ||
expected_ops={torch.ops.aten.split_with_sizes.default}, | ||
disable_passes=True, | ||
) | ||
|
||
@parameterized.expand( | ||
[ | ||
("split_size_or_sections_dims", 2, 1), | ||
] | ||
) | ||
def test_split(self, _, split_size_or_tensor, dim): | ||
class TestModule(torch.nn.Module): | ||
def __init__(self): | ||
super().__init__() | ||
|
||
def forward(self, input): | ||
out = torch.split(input, split_size_or_tensor, dim) | ||
return out | ||
|
||
input = [torch.randn(10).reshape(5, 2)] | ||
self.run_test( | ||
TestModule(), | ||
input, | ||
expected_ops={torch.ops.aten.split.Tensor}, | ||
disable_passes=True, | ||
) | ||
|
||
@parameterized.expand( | ||
[ | ||
("split_size_or_sections_list_dims", [1, 1], 1), | ||
] | ||
) | ||
def test_split_dim_list(self, _, split_size_or_tensor, dim): | ||
class TestModule(torch.nn.Module): | ||
def __init__(self): | ||
super().__init__() | ||
|
||
def forward(self, input): | ||
out = torch.split(input, split_size_or_tensor, dim) | ||
return out | ||
|
||
input = [torch.randn(10).reshape(5, 2)] | ||
self.run_test( | ||
TestModule(), | ||
input, | ||
expected_ops={torch.ops.aten.split_with_sizes.default}, | ||
disable_passes=True, | ||
) | ||
|
||
@parameterized.expand( | ||
[ | ||
("split_size_or_sections_list_dims_not_full_list", [1, 1], 1), | ||
] | ||
) | ||
def test_split_dim_list(self, _, split_size_or_tensor, dim): | ||
class TestModule(torch.nn.Module): | ||
def __init__(self): | ||
super().__init__() | ||
|
||
def forward(self, input): | ||
out = torch.split(input, split_size_or_tensor, dim) | ||
return out | ||
|
||
input = [torch.randn(15).reshape(5, 3)] | ||
with self.assertRaises(RuntimeError): | ||
self.run_test( | ||
TestModule(), | ||
input, | ||
expected_ops={torch.ops.aten.split_with_sizes.default}, | ||
disable_passes=True, | ||
) | ||
|
||
@parameterized.expand( | ||
[ | ||
("select_split_size_or_sections_dim_dynamic_shape", 2, 1), | ||
] | ||
) | ||
def test_split_dynamic(self, _, split_size_or_tensor, dim): | ||
class TestModule(torch.nn.Module): | ||
def __init__(self): | ||
super().__init__() | ||
|
||
def forward(self, input): | ||
out = torch.split(input, split_size_or_tensor, dim) | ||
return out | ||
|
||
input_specs = [ | ||
Input( | ||
shape=(1, 10, -1), | ||
dtype=torch.float32, | ||
shape_ranges=[((1, 10, 1), (1, 10, 10), (1, 10, 10))], | ||
), | ||
] | ||
self.run_test_with_dynamic_shape( | ||
TestModule(), | ||
input_specs, | ||
expected_ops={torch.ops.aten.split.Tensor}, | ||
disable_passes=True, | ||
) | ||
|
||
@parameterized.expand( | ||
[ | ||
("select_chunk_dim", 6, 0), | ||
] | ||
) | ||
def test_split_dynamic(self, _, chunk, dim): | ||
class TestModule(torch.nn.Module): | ||
def __init__(self): | ||
super().__init__() | ||
|
||
def forward(self, input): | ||
out = torch.ops.aten.chunk(input, chunk, dim) | ||
return out | ||
|
||
input = [torch.randn(11)] | ||
with self.assertRaises(UnsupportedOperatorException): | ||
self.run_test( | ||
TestModule(), | ||
input, | ||
expected_ops={torch.ops.aten.split.Tensor}, | ||
disable_passes=True, | ||
) | ||
|
||
|
||
if __name__ == "__main__": | ||
run_tests() |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The
dynamic_unsupported_with_args
import may also have been removed in the rebase