-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.
Already on GitHub? Sign in to your account
[Torch] Add support for split #5174
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -105,6 +105,36 @@ def _impl(inputs, input_types): | |
return _op.transform.strided_slice(data, begin, end, strides) | ||
return _impl | ||
|
||
def _split(): | ||
def _impl(inputs, input_types): | ||
data = inputs[0] | ||
split_size = int(inputs[1]) | ||
dim = int(inputs[2]) | ||
|
||
now_indice = split_size | ||
indices = [] | ||
while now_indice < _infer_shape(data)[dim]: | ||
indices.append(now_indice) | ||
now_indice += split_size | ||
|
||
return _op.split(data, indices, dim) | ||
return _impl | ||
|
||
def _split_with_sizes(): | ||
def _impl(inputs, inputs_types): | ||
data = inputs[0] | ||
dim = int(inputs[2]) | ||
|
||
now_indice = 0 | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. A better name for this variable would be |
||
indices = [] | ||
sections = _infer_shape(inputs[1]) | ||
for i in range(len(sections) - 1): | ||
now_indice += sections[i] | ||
indices.append(now_indice) | ||
|
||
return _op.split(data, indices, dim) | ||
return _impl | ||
|
||
def _select(): | ||
def _impl(inputs, input_types): | ||
data = inputs[0] | ||
|
@@ -886,6 +916,8 @@ def _wrap_const(c): | |
"aten::unsqueeze" : _unsqueeze(), | ||
"aten::cat" : _concatenate(), | ||
"aten::slice" : _slice(), | ||
"aten::split" : _split(), | ||
"aten::split_with_sizes" : _split_with_sizes(), | ||
"aten::select" : _select(), | ||
"aten::relu" : _relu(), | ||
"aten::relu_" : _relu(), | ||
|
@@ -1415,6 +1447,10 @@ def from_pytorch(script_module, input_shapes, custom_convert_map=None): | |
|
||
ret = convert_operators(_get_operator_nodes(graph.nodes()), outputs, | ||
output_index_map, ret_name) | ||
|
||
if isinstance(ret[0], list): | ||
ret[0] = _expr.Tuple(ret[0]) | ||
|
||
func = tvm.relay.Function(_analysis.free_vars(ret[0]), ret[0]) | ||
|
||
return _module.IRModule.from_expr(func), tvm_params |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -379,6 +379,27 @@ def test_forward_maxpool1d(): | |
stride=2).eval(), | ||
input_data) | ||
|
||
def test_forward_split(): | ||
torch.set_grad_enabled(False) | ||
input_shape = [4, 10] | ||
|
||
class Split1(Module): | ||
def forward(self, *args): | ||
return torch.split(args[0], 2, 0) | ||
|
||
class Split2(Module): | ||
def forward(self, *args): | ||
return torch.split(args[0], [2, 3, 5], 1) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Does this end up testing There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yes, the |
||
|
||
class Split3(Module): | ||
def forward(self, *args): | ||
return torch.split(args[0], 3, 1) | ||
|
||
input_data = torch.rand(input_shape).float() | ||
verify_model(Split1().float().eval(), input_data=input_data) | ||
verify_model(Split2().float().eval(), input_data=input_data) | ||
verify_model(Split3().float().eval(), input_data=input_data) | ||
|
||
def test_forward_avgpool(): | ||
torch.set_grad_enabled(False) | ||
input_shape = [1, 3, 10, 10] | ||
|
@@ -1077,6 +1098,7 @@ def forward(self, xs): | |
test_forward_expand() | ||
test_forward_pow() | ||
test_forward_chunk() | ||
test_forward_split() | ||
test_upsample() | ||
test_to() | ||
test_adaptive_pool3d() | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Because this op is splitting the tensor evenly, you can just provide a number of sections to split into rather than a list of indices.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The split op in Relay and PyTorch have different behaviors. In Relay, if indices_or_sections is an integer, the input will be divided equally along the given axis. But in PyTorch, if split_size_or_sections is an integer type, then the tensor will be split into equally sized chunks (if possible). The last chunk will be smaller if the tensor size along the given dimension dim is not divisible by split_size.