Skip to content

Commit b39bd83

Browse files
authored
[Topi, ARM] Disbale Winograd for quantized tensors. (#5363)
* [Topi, ARM] Disbale Winograd for quantized tensors. * Relaxing float
1 parent 5ce2c29 commit b39bd83

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed

python/tvm/relay/op/strategy/arm_cpu.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,16 +59,22 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target):
5959
wrap_compute_conv2d(topi.arm_cpu.conv2d_nchw_spatial_pack),
6060
wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_nchw_spatial_pack),
6161
name="conv2d_nchw_spatial_pack.arm_cpu")
62+
6263
# Intel x86 conv2d schedule.
6364
strategy.add_implementation(
6465
wrap_compute_conv2d(topi.x86.conv2d_nchw),
6566
wrap_topi_schedule(topi.x86.schedule_conv2d_nchw),
6667
name="conv2d_nchw.x86")
68+
6769
# check if winograd algorithm is applicable
6870
_, _, kh, kw = get_const_tuple(kernel.shape)
6971
pt, pl, pb, pr = topi.nn.get_pad_tuple(padding, (kh, kw))
70-
if kh == 3 and kw == 3 and stride_h == 1 and stride_w == 1 and \
71-
dilation_h == 1 and dilation_w == 1:
72+
is_winograd_applicable = "float" in data.dtype and \
73+
"float" in kernel.dtype and \
74+
kh == 3 and kw == 3 and \
75+
stride_h == 1 and stride_w == 1 and \
76+
dilation_h == 1 and dilation_w == 1
77+
if is_winograd_applicable:
7278
strategy.add_implementation(
7379
wrap_compute_conv2d(topi.arm_cpu.conv2d_nchw_winograd),
7480
wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_nchw_winograd),

0 commit comments

Comments
 (0)