Skip to content

Commit 0b76568

Browse files
Josh Frommylc
authored andcommitted
[Relay][Frontend][Onnx] Enable group_conv1d import through conv2d conversion. (apache#8321)
* Enable group conv1d import through conv2d hack. * remove silly commented out lines.
1 parent 860cacd commit 0b76568

File tree

2 files changed

+44
-4
lines changed

2 files changed

+44
-4
lines changed

python/tvm/relay/frontend/onnx.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -451,6 +451,7 @@ class Conv(OnnxOpConverter):
451451
def _impl_v1(cls, inputs, attr, params):
452452
# Use shape of input to determine convolution type.
453453
data = inputs[0]
454+
kernel = inputs[1]
454455
input_shape = infer_shape(data)
455456
ndim = len(input_shape)
456457

@@ -473,13 +474,32 @@ def _impl_v1(cls, inputs, attr, params):
473474
mode=attr["auto_pad"],
474475
)
475476
elif attr["auto_pad"] == "VALID":
476-
attr["pads"] = tuple([0 for i in range(ndim - 2)])
477+
attr["pads"] = [0 for i in range(ndim - 2)]
477478
elif attr["auto_pad"] == "NOTSET":
478479
pass
479480
else:
480481
msg = 'Value {} in attribute "auto_pad" of operator Conv is invalid.'
481482
raise tvm.error.OpAttributeInvalid(msg.format(attr["auto_pad"]))
482483
attr.pop("auto_pad")
484+
485+
# Check if the requested convolution is a group conv1d, if so convert it to conv2d.
486+
# TODO(jwfromm) Remove once proper group_conv1d is supported.
487+
group_conv1d = False
488+
if dimension_picker("conv")(attr) == "conv1d" and attr.get("group") != 1:
489+
group_conv1d = True
490+
# Expand input from NCW to NCHW
491+
data = _op.expand_dims(data, axis=2)
492+
# Expand kernel from OIW to OIHW
493+
kernel = _op.expand_dims(kernel, axis=2)
494+
# Add new value to kernel_shape, strices, dilation, pads, if needed
495+
attr["kernel_shape"] = [1] + list(attr["kernel_shape"])
496+
if "strides" in attr:
497+
attr["strides"] = [1] + list(attr["strides"])
498+
if "dilations" in attr:
499+
attr["dilations"] = [1] + list(attr["dilations"])
500+
if "pads" in attr:
501+
attr["pads"] = [0, attr["pads"][0], 0, attr["pads"][1]]
502+
483503
out = AttrCvt(
484504
op_name=dimension_picker("conv"),
485505
transforms={
@@ -489,7 +509,11 @@ def _impl_v1(cls, inputs, attr, params):
489509
"group": ("groups", 1),
490510
},
491511
custom_check=dimension_constraint(),
492-
)([data, inputs[1]], attr, params)
512+
)([data, kernel], attr, params)
513+
514+
# If this was a group_conv1d, squish output back to NCW.
515+
if group_conv1d:
516+
out = _op.squeeze(out, axis=[2])
493517

494518
use_bias = len(inputs) == 3
495519
if use_bias:

tests/python/frontend/onnx/test_forward.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2410,6 +2410,7 @@ def verify_conv(
24102410
kernel_shape,
24112411
strides,
24122412
dilations,
2413+
group=1,
24132414
auto_pad="NOTSET",
24142415
unset_pad=False,
24152416
):
@@ -2422,7 +2423,7 @@ def verify_conv(
24222423
# Default values for other attributes:
24232424
strides=strides,
24242425
dilations=dilations,
2425-
# groups=1
2426+
group=group,
24262427
)
24272428
elif padding is None:
24282429
## autopadding with unset default attributes
@@ -2438,6 +2439,7 @@ def verify_conv(
24382439
outputs=["y"],
24392440
# Default values for other attributes:
24402441
auto_pad=auto_pad,
2442+
group=group,
24412443
**kwargs,
24422444
)
24432445
else:
@@ -2449,7 +2451,7 @@ def verify_conv(
24492451
# Default values for other attributes:
24502452
strides=strides,
24512453
dilations=dilations,
2452-
# groups=1
2454+
group=group,
24532455
pads=padding,
24542456
)
24552457

@@ -2559,6 +2561,20 @@ def repeat(N, D):
25592561
repeat(2, D),
25602562
)
25612563

2564+
# TODO(jwfromm): Merge with other tests once group_conv3d is supported.
2565+
for D in [1, 2]:
2566+
# Group Convolution
2567+
verify_conv(
2568+
(1, 8) + repeat(5, D),
2569+
(8, 1) + repeat(3, D),
2570+
(1, 8) + repeat(5, D),
2571+
2 * repeat(1, D),
2572+
repeat(3, D),
2573+
repeat(1, D),
2574+
repeat(1, D),
2575+
group=8,
2576+
)
2577+
25622578

25632579
def verify_convtranspose_with_padding(
25642580
x_shape,

0 commit comments

Comments
 (0)