Skip to content

Commit

Permalink
Add default for split op (#9489)
Browse files Browse the repository at this point in the history
* split fix

* add default split test case
  • Loading branch information
anwang2009 authored Nov 11, 2021
1 parent b26ddfe commit 1e09bb2
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 2 deletions.
3 changes: 1 addition & 2 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -1461,9 +1461,8 @@ class Split(OnnxOpConverter):
@classmethod
def _impl_v1(cls, inputs, attr, params):
splits = attr.get("split", None)
if splits is not None:
if splits is not None and len(splits) > 1:
indices = []
attr["indices_or_sections"] = []
index = 0
for i in splits[:-1]:
index += i
Expand Down
3 changes: 3 additions & 0 deletions tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -1966,6 +1966,9 @@ def verify_split(indata, outdatas, split, axis=0, pass_split=True, opset=11):
verify_split([1, 2, 3], [[1], [2], [3]], False, 0, False)
# Split a single value to a single value
verify_split([1], [[1]], [1], pass_split=True)
# Test that the default case modifies nothing when split list has length one
verify_split([[1.0, 2.0]], [[1.0, 2.0]], [2], 1)
verify_split([[1.0, 2.0]], [[1.0, 2.0]], [1], 0)


@tvm.testing.parametrize_targets
Expand Down

0 comments on commit 1e09bb2

Please sign in to comment.