Skip to content

Commit 8802b8c

Browse files
author
Anastasia Stulova
committed
[Relay][TOPI] Misc fixes for depthwise conv2d Mali/Bifrost.
- Fix assert for Bifrost. - Set reasonable default axis splits to avoid using tophub. - Fixed typo: arm cpu -> Mali.
1 parent 0b1ff16 commit 8802b8c

File tree

2 files changed

+10
-5
lines changed

2 files changed

+10
-5
lines changed

python/tvm/relay/op/strategy/bifrost.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def conv2d_strategy_bifrost(attrs, inputs, out_type, target):
8484
name="depthwise_conv2d_nchw.bifrost",
8585
)
8686
elif layout == "NHWC":
87-
assert kernel_layout == "HWIO"
87+
assert kernel_layout == "HWOI"
8888
# For now just reuse general Mali strategy.
8989
strategy.add_implementation(
9090
wrap_compute_conv2d(topi.mali.depthwise_conv2d_nhwc),

python/tvm/topi/mali/depthwise_conv2d.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def depthwise_conv2d_nchw(cfg, data, kernel, strides, padding, dilation, out_dty
3030
return nn.depthwise_conv2d_nchw(data, kernel, strides, padding, dilation, out_dtype)
3131

3232

33-
# register customized schedule for arm cpu.
33+
# register customized schedule for Mali.
3434
@autotvm.register_topi_schedule("depthwise_conv2d_nchw.mali")
3535
def schedule_depthwise_conv2d_nchw(cfg, outs):
3636
"""Schedule depthwise conv2d
@@ -70,7 +70,7 @@ def depthwise_conv2d_nhwc(cfg, data, kernel, strides, padding, dilation, out_dty
7070
return nn.depthwise_conv2d_nhwc(data, kernel, strides, padding, dilation, out_dtype)
7171

7272

73-
# register customized schedule for arm cpu.
73+
# register customized schedule for Mali.
7474
@autotvm.register_topi_schedule("depthwise_conv2d_nhwc.mali")
7575
def schedule_depthwise_conv2d_nhwc(cfg, outs):
7676
"""Schedule depthwise conv2d
@@ -124,8 +124,13 @@ def _schedule(cfg, s, pad_data, kernel, conv, layout):
124124

125125
# fallback support
126126
if cfg.is_fallback:
127-
ref_log = autotvm.tophub.load_reference_log("mali", "rk3399", "depthwise_conv2d_nchw.mali")
128-
cfg.fallback_with_reference_log(ref_log)
127+
if layout == "NCHW":
128+
ref_log = autotvm.tophub.load_reference_log("mali", "rk3399", "depthwise_conv2d_nchw.mali")
129+
cfg.fallback_with_reference_log(ref_log)
130+
else:
131+
cfg.fallback_split("tile_c", [-1, 4, 2])
132+
cfg.fallback_split("tile_y", [-1, 4, 2])
133+
cfg.fallback_split("tile_x", [-1, 4, 2])
129134
###### space definition end ######
130135

131136
# schedule padding

0 commit comments

Comments
 (0)