Skip to content

Commit 947b488

Browse files
committed
[AlterOpLayout][x86] NHWC to NCHWc conv support.
1 parent 76c2392 commit 947b488

File tree

1 file changed

+13
-14
lines changed

1 file changed

+13
-14
lines changed

topi/python/topi/x86/conv2d_alter_op.py

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -158,20 +158,19 @@ def _alter_conv2d_layout(attrs, inputs, tinfo, F):
158158
return None
159159
return F.nn.contrib_conv2d_nchwc_int8(*copy_inputs, **new_attrs)
160160

161-
if data_layout == 'NCHW' and attrs['kernel_layout'] == 'OIHW':
162-
# (oc, ic, h, w) -> (OC, IC, h, w, ic, oc)
163-
new_attrs['kernel_layout'] = 'OIHW%di%do' % (ic_bn, oc_bn)
164-
# Store altered operator's config
165-
new_kernel = tvm.placeholder((out_channel//oc_bn, in_channel//ic_bn,
166-
kh, kw, ic_bn, oc_bn), dtype=kernel_tensor.dtype)
167-
new_workload = autotvm.task.args_to_workload(
168-
[new_data, new_kernel, strides, padding, dilation, new_attrs[layout_name],
169-
new_attrs['out_layout'], out_dtype], conv2d_NCHWc)
170-
dispatch_ctx.update(target, new_workload, cfg)
171-
172-
if F.__name__ == 'nnvm.symbol':
173-
return F.contrib.conv2d_NCHWc(*copy_inputs, **new_attrs)
174-
return F.nn.contrib_conv2d_nchwc(*copy_inputs, **new_attrs)
161+
# (oc, ic, h, w) -> (OC, IC, h, w, ic, oc)
162+
new_attrs['kernel_layout'] = 'OIHW%di%do' % (ic_bn, oc_bn)
163+
# Store altered operator's config
164+
new_kernel = tvm.placeholder((out_channel//oc_bn, in_channel//ic_bn,
165+
kh, kw, ic_bn, oc_bn), dtype=kernel_tensor.dtype)
166+
new_workload = autotvm.task.args_to_workload(
167+
[new_data, new_kernel, strides, padding, dilation, new_attrs[layout_name],
168+
new_attrs['out_layout'], out_dtype], conv2d_NCHWc)
169+
dispatch_ctx.update(target, new_workload, cfg)
170+
171+
if F.__name__ == 'nnvm.symbol':
172+
return F.contrib.conv2d_NCHWc(*copy_inputs, **new_attrs)
173+
return F.nn.contrib_conv2d_nchwc(*copy_inputs, **new_attrs)
175174
return None
176175

177176

0 commit comments

Comments
 (0)