We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 662bcb4 commit 8994359Copy full SHA for 8994359
python/tvm/relay/frontend/mxnet.py
@@ -932,7 +932,12 @@ def _mx_repeat(inputs, attrs):
932
assert len(inputs) == 1
933
new_attrs = {}
934
new_attrs["repeats"] = attrs.get_int("repeats")
935
- new_attrs["axis"] = attrs.get_int("axis", 0)
+ axis = attrs.get_int("axis", None)
936
+ if axis is None:
937
+ inputs[0] = _op.nn.batch_flatten(inputs[0])
938
+ new_attrs["axis"] = 0
939
+ else:
940
+ new_attrs["axis"] = axis
941
return _op.repeat(inputs[0], **new_attrs)
942
943
0 commit comments