Skip to content

Commit 8994359

Browse files
authored
[Relay][Bugfix] fix axis parsing of repeat converter in the MXNet frontend (#15891)
fix axis in repeat
1 parent 662bcb4 commit 8994359

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

python/tvm/relay/frontend/mxnet.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -932,7 +932,12 @@ def _mx_repeat(inputs, attrs):
932932
assert len(inputs) == 1
933933
new_attrs = {}
934934
new_attrs["repeats"] = attrs.get_int("repeats")
935-
new_attrs["axis"] = attrs.get_int("axis", 0)
935+
axis = attrs.get_int("axis", None)
936+
if axis is None:
937+
inputs[0] = _op.nn.batch_flatten(inputs[0])
938+
new_attrs["axis"] = 0
939+
else:
940+
new_attrs["axis"] = axis
936941
return _op.repeat(inputs[0], **new_attrs)
937942

938943

0 commit comments

Comments
 (0)