Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relay][TOPI] Support of depthwise conv2d NHWC for Mali/Bifrost. #8584

Merged
merged 3 commits into from
Aug 26, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions python/tvm/relay/op/strategy/bifrost.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,14 @@ def conv2d_strategy_bifrost(attrs, inputs, out_type, target):
wrap_topi_schedule(topi.bifrost.schedule_depthwise_conv2d_nchw),
name="depthwise_conv2d_nchw.bifrost",
)
elif layout == "NHWC":
assert kernel_layout == "HWOI"
# For now just reuse general Mali strategy.
strategy.add_implementation(
wrap_compute_conv2d(topi.mali.depthwise_conv2d_nhwc),
wrap_topi_schedule(topi.mali.schedule_depthwise_conv2d_nhwc),
name="depthwise_conv2d_nchw.bifrost",
)
else:
raise RuntimeError(
"Unsupported depthwise_conv2d layout {} for Mali(Bifrost)".format(layout)
Expand Down
17 changes: 10 additions & 7 deletions python/tvm/relay/op/strategy/mali.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,14 +120,17 @@ def conv2d_strategy_mali(attrs, inputs, out_type, target):
elif layout == "NHWC":
assert kernel_layout == "HWOI"
if not is_auto_scheduler_enabled():
raise RuntimeError(
"depthwise_conv2d NHWC layout is not enabled for mali without auto_scheduler."
strategy.add_implementation(
wrap_compute_conv2d(topi.mali.depthwise_conv2d_nhwc),
wrap_topi_schedule(topi.mali.schedule_depthwise_conv2d_nhwc),
name="depthwise_conv2d_nhwc.mali",
)
else:
strategy.add_implementation(
wrap_compute_conv2d(topi.nn.depthwise_conv2d_nhwc),
naive_schedule,
name="depthwise_conv2d_nhwc.mali",
)
strategy.add_implementation(
wrap_compute_conv2d(topi.nn.depthwise_conv2d_nhwc),
naive_schedule,
name="depthwise_conv2d_nhwc.mali",
)
else:
raise RuntimeError("Unsupported depthwise_conv2d layout {} for mali".format(layout))
else: # group_conv2d
Expand Down
200 changes: 136 additions & 64 deletions python/tvm/topi/mali/depthwise_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def depthwise_conv2d_nchw(cfg, data, kernel, strides, padding, dilation, out_dty
return nn.depthwise_conv2d_nchw(data, kernel, strides, padding, dilation, out_dtype)


# register customized schedule for arm cpu.
# register customized schedule for Mali.
@autotvm.register_topi_schedule("depthwise_conv2d_nchw.mali")
def schedule_depthwise_conv2d_nchw(cfg, outs):
"""Schedule depthwise conv2d
Expand All @@ -51,86 +51,158 @@ def schedule_depthwise_conv2d_nchw(cfg, outs):
outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
s = te.create_schedule([x.op for x in outs])

def _schedule(pad_data, kernel, conv):
"""schedule depthwise_conv2d"""
max_unroll = 16
vec_size = [1, 2, 4, 8, 16]
def _callback(op):
"""traverse to find op to schedule"""
# schedule depthwise_conv2d
if op.tag == "depthwise_conv2d_nchw":
pad_data = op.input_tensors[0]
kernel = op.input_tensors[1]
conv = op.output(0)
_schedule(cfg, s, pad_data, kernel, conv, "NCHW")

##### space definition begin #####
n, c, y, x = s[conv].op.axis
bc, tc, ci = cfg.define_split("tile_c", c, num_outputs=3)
by, ty, yi = cfg.define_split("tile_y", y, num_outputs=3)
bx, tx, xi = cfg.define_split("tile_x", x, num_outputs=3)
cfg.define_annotate("ann_spatial", [ci, yi, xi], policy="try_unroll_vec")
traverse_inline(s, outs[0].op, _callback)
return s

# fallback support
if cfg.is_fallback:
ref_log = autotvm.tophub.load_reference_log(
"mali", "rk3399", "depthwise_conv2d_nchw.mali"
)
cfg.fallback_with_reference_log(ref_log)
###### space definition end ######

# schedule padding
n, c, y, x = s[pad_data].op.axis
tile_and_bind3d(s, pad_data, c, y, x, cfg["tile_c"].size[1], 1, 1)
# register original implementation of depthwise_conv2d_nhwc since we don't need to change this part
@autotvm.register_topi_compute("depthwise_conv2d_nhwc.mali")
def depthwise_conv2d_nhwc(cfg, data, kernel, strides, padding, dilation, out_dtype):
return nn.depthwise_conv2d_nhwc(data, kernel, strides, padding, dilation, out_dtype)

# schedule dilation
if isinstance(kernel.op, tvm.te.ComputeOp) and "dilate" in kernel.op.tag:
s[kernel].compute_inline()

# schedule conv
if conv.op not in s.outputs:
s[conv].set_scope("local")
OL = conv
output = s.outputs[0].output(0)
else:
OL = s.cache_write(conv, "local")
output = conv

n, c, y, x = s[output].op.axis
bc, tc, ci = cfg["tile_c"].apply(s, output, c)
by, ty, yi = cfg["tile_y"].apply(s, output, y)
bx, tx, xi = cfg["tile_x"].apply(s, output, x)

bc = s[output].fuse(n, bc)
s[output].bind(bc, te.thread_axis("blockIdx.z"))
s[output].bind(tc, te.thread_axis("threadIdx.z"))
s[output].bind(by, te.thread_axis("blockIdx.y"))
s[output].bind(ty, te.thread_axis("threadIdx.y"))
s[output].bind(bx, te.thread_axis("blockIdx.x"))
s[output].bind(tx, te.thread_axis("threadIdx.x"))

di, dj = s[OL].op.reduce_axis
s[OL].unroll(di)
s[OL].unroll(dj)

s[OL].compute_at(s[output], tx)
n, ci, yi, xi = s[OL].op.axis

cfg["ann_spatial"].apply(
s,
OL,
[ci, yi, xi],
axis_lens=[cfg["tile_c"].size[2], cfg["tile_y"].size[2], cfg["tile_x"].size[2]],
max_unroll=max_unroll,
vec_size=vec_size,
cfg=cfg,
)
# register customized schedule for Mali.
@autotvm.register_topi_schedule("depthwise_conv2d_nhwc.mali")
def schedule_depthwise_conv2d_nhwc(cfg, outs):
"""Schedule depthwise conv2d

Parameters
----------
cfg: ConfigEntity
The configuration of this template
outs: Array of Tensor
The computation graph description of depthwise convolution2d
in the format of an array of tensors.

Returns
-------
s: Schedule
The computation schedule for depthwise_conv2d nchw.
"""
outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
s = te.create_schedule([x.op for x in outs])

def _callback(op):
"""traverse to find op to schedule"""
# schedule depthwise_conv2d
if op.tag == "depthwise_conv2d_nchw":
if op.tag == "depthwise_conv2d_nhwc":
pad_data = op.input_tensors[0]
kernel = op.input_tensors[1]
conv = op.output(0)
_schedule(pad_data, kernel, conv)
_schedule(cfg, s, pad_data, kernel, conv, "NHWC")

traverse_inline(s, outs[0].op, _callback)
return s


def _schedule(cfg, s, pad_data, kernel, conv, layout):
"""schedule depthwise_conv2d"""
assert layout in ("NCHW", "NHWC")

max_unroll = 16
vec_size = [1, 2, 4, 8, 16]

##### space definition begin #####
if layout == "NCHW":
n, c, h, w = s[conv].op.axis
else:
n, h, w, c = s[conv].op.axis

bc, tc, ci = cfg.define_split("tile_c", c, num_outputs=3)
bh, th, hi = cfg.define_split("tile_y", h, num_outputs=3)
bw, tw, wi = cfg.define_split("tile_x", w, num_outputs=3)
cfg.define_annotate("ann_spatial", [ci, hi, wi], policy="try_unroll_vec")

# fallback support
if cfg.is_fallback:
if layout == "NCHW":
ref_log = autotvm.tophub.load_reference_log(
"mali", "rk3399", "depthwise_conv2d_nchw.mali"
)
cfg.fallback_with_reference_log(ref_log)
else:
cfg.fallback_split("tile_c", [-1, 4, 2])
cfg.fallback_split("tile_y", [-1, 4, 2])
cfg.fallback_split("tile_x", [-1, 4, 2])
###### space definition end ######

# schedule padding
if layout == "NCHW":
n, c, h, w = s[pad_data].op.axis
z, y, x = c, h, w
z_factor, y_factor, x_factor = cfg["tile_c"].size[1], 1, 1
else:
n, h, w, c = s[pad_data].op.axis
z, y, x = h, w, c
z_factor, y_factor, x_factor = 1, 1, cfg["tile_c"].size[1]
tile_and_bind3d(s, pad_data, z, y, x, z_factor, y_factor, x_factor)

# schedule dilation
if isinstance(kernel.op, tvm.te.ComputeOp) and "dilate" in kernel.op.tag:
s[kernel].compute_inline()

# schedule conv
if conv.op not in s.outputs:
s[conv].set_scope("local")
OL = conv
output = s.outputs[0].output(0)
else:
OL = s.cache_write(conv, "local")
output = conv

if layout == "NCHW":
n, c, h, w = s[output].op.axis
else:
n, h, w, c = s[output].op.axis

bc, tc, ci = cfg["tile_c"].apply(s, output, c)
bh, th, hi = cfg["tile_y"].apply(s, output, h)
bw, tw, wi = cfg["tile_x"].apply(s, output, w)

if layout == "NCHW":
bz, tz, by, ty, bx, tx = bc, tc, bh, th, bw, tw
else:
bz, tz, by, ty, bx, tx = bh, th, bw, tw, bc, tc

bz = s[output].fuse(n, bz)
s[output].bind(bz, te.thread_axis("blockIdx.z"))
s[output].bind(tz, te.thread_axis("threadIdx.z"))
s[output].bind(by, te.thread_axis("blockIdx.y"))
s[output].bind(ty, te.thread_axis("threadIdx.y"))
s[output].bind(bx, te.thread_axis("blockIdx.x"))
s[output].bind(tx, te.thread_axis("threadIdx.x"))

di, dj = s[OL].op.reduce_axis
s[OL].unroll(di)
s[OL].unroll(dj)

s[OL].compute_at(s[output], tx)

if layout == "NCHW":
n, ci, hi, wi = s[OL].op.axis
else:
n, hi, wi, ci = s[OL].op.axis

cfg["ann_spatial"].apply(
s,
OL,
[ci, hi, wi],
axis_lens=[cfg["tile_c"].size[2], cfg["tile_y"].size[2], cfg["tile_x"].size[2]],
max_unroll=max_unroll,
vec_size=vec_size,
cfg=cfg,
)


def tile_and_bind3d(s, tensor, z, y, x, z_factor=2, y_factor=None, x_factor=None):
"""tile and bind 3d"""
y_factor = y_factor or z_factor
Expand Down
2 changes: 2 additions & 0 deletions tests/python/topi/python/test_topi_depthwise_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@
)
],
"gpu": [(topi.nn.depthwise_conv2d_nhwc, topi.cuda.schedule_depthwise_conv2d_nhwc)],
"mali": [(topi.mali.depthwise_conv2d_nhwc, topi.mali.schedule_depthwise_conv2d_nhwc)],
"bifrost": [(topi.mali.depthwise_conv2d_nhwc, topi.mali.schedule_depthwise_conv2d_nhwc)],
},
"NCHWc": {
"generic": [(topi.x86.depthwise_conv2d_NCHWc, topi.x86.schedule_depthwise_conv2d_NCHWc)],
Expand Down