@@ -59,16 +59,22 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target):
5959 wrap_compute_conv2d (topi .arm_cpu .conv2d_nchw_spatial_pack ),
6060 wrap_topi_schedule (topi .arm_cpu .schedule_conv2d_nchw_spatial_pack ),
6161 name = "conv2d_nchw_spatial_pack.arm_cpu" )
62+
6263 # Intel x86 conv2d schedule.
6364 strategy .add_implementation (
6465 wrap_compute_conv2d (topi .x86 .conv2d_nchw ),
6566 wrap_topi_schedule (topi .x86 .schedule_conv2d_nchw ),
6667 name = "conv2d_nchw.x86" )
68+
6769 # check if winograd algorithm is applicable
6870 _ , _ , kh , kw = get_const_tuple (kernel .shape )
6971 pt , pl , pb , pr = topi .nn .get_pad_tuple (padding , (kh , kw ))
70- if kh == 3 and kw == 3 and stride_h == 1 and stride_w == 1 and \
71- dilation_h == 1 and dilation_w == 1 :
72+ is_winograd_applicable = "float" in data .dtype and \
73+ "float" in kernel .dtype and \
74+ kh == 3 and kw == 3 and \
75+ stride_h == 1 and stride_w == 1 and \
76+ dilation_h == 1 and dilation_w == 1
77+ if is_winograd_applicable :
7278 strategy .add_implementation (
7379 wrap_compute_conv2d (topi .arm_cpu .conv2d_nchw_winograd ),
7480 wrap_topi_schedule (topi .arm_cpu .schedule_conv2d_nchw_winograd ),
0 commit comments