Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Torch] Add support for split #5174

Merged
merged 3 commits into from
Mar 31, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,36 @@ def _impl(inputs, input_types):
return _op.transform.strided_slice(data, begin, end, strides)
return _impl

def _split():
def _impl(inputs, input_types):
data = inputs[0]
split_size = int(inputs[1])
dim = int(inputs[2])

now_indice = split_size
indices = []
while now_indice < _infer_shape(data)[dim]:
indices.append(now_indice)
now_indice += split_size

return _op.split(data, indices, dim)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because this op is splitting the index evenly, you can just provide a number of sections to split to rather than a list of indices.

return _op.split(data, int(_infer_shape(data)[dim] / split_size), dim)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The split op in Relay and PyTorch have different behaviors. In Relay, if indices_or_sections is an integer, the input will be divided equally along the given axis. But in Pytorch, if split_size_or_sections is an integer type, then tensor will be split into equally sized chunks (if possible). The last chunk will be smaller if the tensor size along the given dimension dim is not divisible by split_size.

return _impl

def _split_with_sizes():
def _impl(inputs, inputs_types):
data = inputs[0]
dim = int(inputs[2])

now_indice = 0
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A better name for this variable would be split_index.

indices = []
sections = _infer_shape(inputs[1])
for i in range(len(sections) - 1):
now_indice += sections[i]
indices.append(now_indice)

return _op.split(data, indices, dim)
return _impl

def _select():
def _impl(inputs, input_types):
data = inputs[0]
Expand Down Expand Up @@ -886,6 +916,8 @@ def _wrap_const(c):
"aten::unsqueeze" : _unsqueeze(),
"aten::cat" : _concatenate(),
"aten::slice" : _slice(),
"aten::split" : _split(),
"aten::split_with_sizes" : _split_with_sizes(),
"aten::select" : _select(),
"aten::relu" : _relu(),
"aten::relu_" : _relu(),
Expand Down Expand Up @@ -1415,6 +1447,10 @@ def from_pytorch(script_module, input_shapes, custom_convert_map=None):

ret = convert_operators(_get_operator_nodes(graph.nodes()), outputs,
output_index_map, ret_name)

if isinstance(ret[0], list):
ret[0] = _expr.Tuple(ret[0])

func = tvm.relay.Function(_analysis.free_vars(ret[0]), ret[0])

return _module.IRModule.from_expr(func), tvm_params
22 changes: 22 additions & 0 deletions tests/python/frontend/pytorch/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,27 @@ def test_forward_maxpool1d():
stride=2).eval(),
input_data)

def test_forward_split():
torch.set_grad_enabled(False)
input_shape = [4, 10]

class Split1(Module):
def forward(self, *args):
return torch.split(args[0], 2, 0)

class Split2(Module):
def forward(self, *args):
return torch.split(args[0], [2, 3, 5], 1)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this end up testing split_with_sizes?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, the Split2 will test split_with_sizes.


class Split3(Module):
def forward(self, *args):
return torch.split(args[0], 3, 1)

input_data = torch.rand(input_shape).float()
verify_model(Split1().float().eval(), input_data=input_data)
verify_model(Split2().float().eval(), input_data=input_data)
verify_model(Split3().float().eval(), input_data=input_data)

def test_forward_avgpool():
torch.set_grad_enabled(False)
input_shape = [1, 3, 10, 10]
Expand Down Expand Up @@ -1077,6 +1098,7 @@ def forward(self, xs):
test_forward_expand()
test_forward_pow()
test_forward_chunk()
test_forward_split()
test_upsample()
test_to()
test_adaptive_pool3d()
Expand Down