Skip to content

Commit

Permalink
[Torch] Add support for split (apache#5174)
Browse files Browse the repository at this point in the history
* [Torch] Add support for split

* fix

* fix test class
  • Loading branch information
wyc-ruiker authored and Trevor Morris committed Apr 16, 2020
1 parent c8770ff commit 735a8e6
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 0 deletions.
36 changes: 36 additions & 0 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,36 @@ def _impl(inputs, input_types):
return _op.transform.strided_slice(data, begin, end, strides)
return _impl

def _split():
def _impl(inputs, input_types):
data = inputs[0]
split_size = int(inputs[1])
dim = int(inputs[2])

split_index = split_size
indices = []
while split_index < _infer_shape(data)[dim]:
indices.append(split_index)
split_index += split_size

return _op.split(data, indices, dim)
return _impl

def _split_with_sizes():
def _impl(inputs, inputs_types):
data = inputs[0]
dim = int(inputs[2])

split_index = 0
indices = []
sections = _infer_shape(inputs[1])
for i in range(len(sections) - 1):
split_index += sections[i]
indices.append(split_index)

return _op.split(data, indices, dim)
return _impl

def _select():
def _impl(inputs, input_types):
data = inputs[0]
Expand Down Expand Up @@ -886,6 +916,8 @@ def _wrap_const(c):
"aten::unsqueeze" : _unsqueeze(),
"aten::cat" : _concatenate(),
"aten::slice" : _slice(),
"aten::split" : _split(),
"aten::split_with_sizes" : _split_with_sizes(),
"aten::select" : _select(),
"aten::relu" : _relu(),
"aten::relu_" : _relu(),
Expand Down Expand Up @@ -1415,6 +1447,10 @@ def from_pytorch(script_module, input_shapes, custom_convert_map=None):

ret = convert_operators(_get_operator_nodes(graph.nodes()), outputs,
output_index_map, ret_name)

if isinstance(ret[0], list):
ret[0] = _expr.Tuple(ret[0])

func = tvm.relay.Function(_analysis.free_vars(ret[0]), ret[0])

return _module.IRModule.from_expr(func), tvm_params
24 changes: 24 additions & 0 deletions tests/python/frontend/pytorch/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,29 @@ def test_forward_maxpool1d():
stride=2).eval(),
input_data)

def test_forward_split():
torch.set_grad_enabled(False)
input_shape = [4, 10]

class Split(Module):
def __init__(self, split_size_or_sections, dim):
super(Split, self).__init__()
self.split_size_or_sections = split_size_or_sections
self.dim = dim

def forward(self, *args):
return torch.split(args[0], self.split_size_or_sections, self.dim)

input_data = torch.rand(input_shape).float()
verify_model(Split(2, 0).float().eval(),
input_data=input_data)
verify_model(Split(3, 1).float().eval(),
input_data=input_data)
verify_model(Split(4, 1).float().eval(),
input_data=input_data)
verify_model(Split([2, 3, 5], 1).float().eval(),
input_data=input_data)

def test_forward_avgpool():
torch.set_grad_enabled(False)
input_shape = [1, 3, 10, 10]
Expand Down Expand Up @@ -1077,6 +1100,7 @@ def forward(self, xs):
test_forward_expand()
test_forward_pow()
test_forward_chunk()
test_forward_split()
test_upsample()
test_to()
test_adaptive_pool3d()
Expand Down

0 comments on commit 735a8e6

Please sign in to comment.