Commit 7fa68ba

[Relay][AutoTVM] Relay op strategy (apache#4644)
* relay op strategy fix lint bitpack strategy bitserial_dense (apache#6)
* update strategy
* address comments fix a few topi test Dense strategy (apache#5)
* dense
* add biforst; remove comments
* address comment Refactor x86 conv2d_NCHWc (#4)
* Refactor x86 conv2d
* Add x86 depthwise_conv2d_NCHWc
* Add back topi x86 conv2d_nchw
* Merge x86 conv2d_nchw and conv2d_NCHWc
* Minor fix for x86 conv2d fix more strategy Add x86 conv2d_NCHWc_int8 strategy (apache#8)
* Add x86 conv2d_NCHWc_int8 strategy
* Remove contrib_conv2d_nchwc_int8
* Fix generic conv2d_NCHWc for int8
* Fix topi arm_cpu conv2d_NCHWc_int8 update x86 conv2d enable specify relay ops to be tuned for autotvm add cuda conv2d strategy add conv2d strategy for rocm add conv2d strategy for hls add conv2d strategy for arm cpu add conv2d strategy for mali add conv2d strategy for bifrost add conv2d strategy for intel graphics clean up and fix lint remove template keys from autotvm remove 2 in the func name address comments fix
* fix bugs
* lint
* address comments
* add name to op implement
* Modify topi tests (apache#9)
* Add pooling, reorg, softmax and vision
* Add lrn
* fix topi test
* fix more topi test
* lint
* address comments
* x
* fix more tests & bugs
* Modify more tests (apache#10)
* Modify tests for bitserial_conv2d, bitserial_dense, bitserial_conv2d_rasp and bnn
* Minor fix
* More minor fix
* fix more test
* try to update vta using strategy
* fix cpptest
* x
* fix rebase err
* Fix two tests (apache#11)
* change autotvm log format
* lint
* minor fix
* try fix vta test
* fix rebase err
* tweak
* tmp hack for vta pass
* fix tutorial
* fix
* fix more tutorials
* fix vta tutorial
* minor
* address comments
* fix
* address comments
* fix cpptest
* fix docs
* change data structure name and api
* address comments
* lint
* fix rebase err
* updates
* fix winograd test
* fix doc
* rebase
* upgrade tophub version number
* fix bug
* re-enable vta tsim test after tophub is upgraded
* fix vta test to use the correct args so the config can be found in tophub

Co-authored-by: Yao Wang <kevinthesunwy@gmail.com>
1 parent 2753b8b commit 7fa68ba

14 files changed: +172 / −203 lines
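At its core, the commit moves VTA's operator overrides onto the FTVMStrategy mechanism: each target registers a strategy function that returns an OpStrategy holding one or more (compute, schedule) implementations. The following is a minimal sketch of that pattern, not part of the commit itself; it assumes only the helpers that appear in the diffs below (OpStrategy, wrap_topi_schedule, topi.generic.schedule_injective) and uses a hypothetical target key "my_accel".

import topi
from tvm.relay.op import op as reg
from tvm.relay.op import strategy as _strategy
from tvm.relay.op.op import OpStrategy

def compute_clip_example(attrs, inputs, output_type):
    """Hypothetical compute for the existing Relay "clip" op."""
    return [topi.clip(inputs[0], attrs.a_min, attrs.a_max)]

def clip_strategy_example(attrs, inputs, out_type, target):
    """Return an OpStrategy with a single (compute, schedule) implementation."""
    strategy = OpStrategy()
    strategy.add_implementation(
        compute_clip_example,
        _strategy.wrap_topi_schedule(topi.generic.schedule_injective),
        name="clip.my_accel")
    return strategy

# "my_accel" is a hypothetical target key; op.py below does the same for "vta".
reg.get("clip").get_attr("FTVMStrategy").register(clip_strategy_example, "my_accel")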

python/vta/ir_pass.py

Lines changed: 6 additions & 2 deletions
@@ -662,8 +662,12 @@ def _do_fold(op):
                             0, 0,
                             0, 0, 0))
             inner = irb.get()
-            args = op.body.body.args
-            res_tensor = op.body.body.func.output(0)
+            # TODO(@tmoreau89): This is only a temporary fix, please take a look.
+            body = op.body.body
+            while isinstance(body, tvm.stmt.IfThenElse):
+                body = body.then_case
+            args = body.args
+            res_tensor = body.func.output(0)
             tpl = (args[0], 1, args[1], 1, args[2], 1, args[3], 1, 0, 1, 0, env.BLOCK_OUT)
             inner = tvm.tir.AttrStmt(
                 [dout, res_tensor], 'buffer_bind_scope',

python/vta/top/__init__.py

Lines changed: 4 additions & 4 deletions
@@ -20,8 +20,8 @@
 from . import bitpack
 from .graphpack import graph_pack
 from . import op
-from . import vta_conv2d
-from . import vta_conv2d_transpose
-from . import vta_group_conv2d
-from . import vta_dense
+from .vta_conv2d import conv2d_packed, schedule_conv2d_packed
+from .vta_conv2d_transpose import conv2d_transpose_packed, schedule_conv2d_transpose_packed
+from .vta_group_conv2d import group_conv2d_packed, schedule_group_conv2d_packed
+from .vta_dense import dense_packed, schedule_dense_packed
 from . import util

python/vta/top/bitpack.py

Lines changed: 2 additions & 3 deletions
@@ -22,9 +22,8 @@
 import tvm
 from topi import util
 
-from tvm.relay.op.op import register_compute, register_schedule
+from tvm.relay.op.op import register_compute, register_injective_schedule
 from tvm.relay.op.op import register_pattern, OpPattern
-from tvm.relay.op.op import schedule_injective
 
 def bitpack(data, bits, pack_type="int8", name="bitpack"):
     """Packs lowest dimension into format needed by VTA
@@ -86,5 +85,5 @@ def compute_bitpack(attrs, inputs):
     bits = 8 // lanes
     return bitpack(inputs[0], bits, dtype)
 
-register_schedule("bitpack", schedule_injective)
+register_injective_schedule("bitpack")
 register_pattern("bitpack", OpPattern.INJECTIVE)

python/vta/top/op.py

Lines changed: 80 additions & 137 deletions
@@ -22,19 +22,22 @@
 import topi
 
 from tvm.relay.op import op as reg
-from tvm.relay.op.op import OpPattern
-from tvm.relay.op.nn import _nn
+from tvm.relay.op import strategy as _strategy
+from tvm.relay.op.op import OpPattern, OpStrategy
 
 from .util import is_packed_layout
+from .vta_conv2d import conv2d_packed, schedule_conv2d_packed
+from .vta_conv2d_transpose import conv2d_transpose_packed, schedule_conv2d_transpose_packed
+from .vta_group_conv2d import group_conv2d_packed, schedule_group_conv2d_packed
+from .vta_dense import dense_packed, schedule_dense_packed
 from ..environment import get_env
 
 
 # override to force partition at copy
 reg.register_pattern("copy", OpPattern.INJECTIVE, level=15)
 
-
-@reg.register_compute("clip", level=15)
-def compute_clip(attrs, inputs, output_type, target):
+# add clip vta strategy
+def compute_clip_vta(attrs, inputs, output_type):
     """ Clip operator. """
     x = inputs[0]
     a_min = attrs.a_min
@@ -48,139 +51,79 @@ def compute_clip(attrs, inputs, output_type, target):
         x.shape, lambda *i: tvm.max(x(*i), const_min), name="clipB")
     return [x]
 
-
-@reg.register_compute("nn.conv2d", level=15)
-def compute_conv2d(attrs, inputs, output_type, target):
-    """ Compute definition of conv2d """
-    padding = topi.util.get_const_tuple(attrs.padding)
-    strides = topi.util.get_const_tuple(attrs.strides)
-    dilation = tuple([int(d) for d in attrs.dilation])
+def clip_strategy_vta(attrs, inputs, out_type, target):
+    strategy = OpStrategy()
+    strategy.add_implementation(
+        compute_clip_vta,
+        _strategy.wrap_topi_schedule(topi.generic.schedule_injective),
+        name="clip.vta")
+    return strategy
+
+reg.get("clip").get_attr("FTVMStrategy").register(clip_strategy_vta, "vta")
+
+@_strategy.conv2d_strategy.register("vta")
+def conv2d_strategy_vta(attrs, inputs, out_type, target):
+    """conv2d vta strategy"""
+    strategy = OpStrategy()
+    kernel = inputs[1]
+    dilation = topi.util.get_const_tuple(attrs.dilation)
     groups = attrs.groups
     layout = attrs.data_layout
-    out_dtype = attrs.out_dtype
-
-    if target.device_name == "vta":
-        assert dilation == (1, 1), "support for dilation limited to (1, 1)"
-        if is_packed_layout(layout):
-            if groups == 1:
-                assert groups == 1
-                env = get_env()
-                assert env.LOG_INP_WIDTH == 3, "only support 8bit inp for now"
-                assert env.LOG_WGT_WIDTH == 3, "only support 8bit wgt for now"
-                inputs = list(inputs)
-                assert inputs[1].dtype == "int8"
-                return [topi.nn.conv2d(inputs[0],
-                                       inputs[1],
-                                       strides,
-                                       padding,
-                                       dilation,
-                                       layout,
-                                       out_dtype)]
-            return [topi.nn.group_conv2d_nchw(inputs[0],
-                                              inputs[1],
-                                              strides,
-                                              padding,
-                                              dilation,
-                                              groups,
-                                              out_dtype)]
-        # If it's not packed, run on ARM CPU
-        with tvm.target.arm_cpu(tvm.target.Target.current().model):
-            return _nn.compute_conv2d(attrs, inputs, output_type, target)
-
-    # If VTA is not the target, default to _nn def
-    return _nn.compute_conv2d(attrs, inputs, output_type, target)
-
-
-@reg.register_schedule("nn.conv2d", level=15)
-def schedule_conv2d(attrs, outs, target):
-    """ Schedule definition of conv2d """
-    groups = attrs.groups
-    layout = attrs.data_layout
-
-    if target.device_name == "vta":
-        if is_packed_layout(layout):
-            target = tvm.target.create(target)
-            assert target.device_name == "vta"
-            if groups == 1:
-                return topi.generic.schedule_conv2d_nchw(outs)
-            return topi.generic.schedule_group_conv2d_nchw(outs)
-        # If it's not packed, run on ARM CPU
-        with tvm.target.arm_cpu(tvm.target.Target.current().model):
-            return _nn.schedule_conv2d(attrs, outs, tvm.target.Target.current())
-
-    # If VTA is not the target, default to _nn def
-    return _nn.schedule_conv2d(attrs, outs, target)
-
-
-@reg.register_compute("nn.conv2d_transpose", level=15)
-def compute_conv2d_transpose(attrs, inputs, output_type, target):
-    """ 2D convolution algorithm.
-    """
-    padding = topi.util.get_const_tuple(attrs.padding)
-    strides = topi.util.get_const_tuple(attrs.strides)
-    dilation = tuple([int(d) for d in attrs.dilation])
-    layout = attrs.data_layout
-    out_dtype = attrs.out_dtype
-
-    if target.device_name == "vta":
-        assert dilation == (1, 1), "support for dilation limited to (1, 1)"
-        if is_packed_layout(layout):
-            return [topi.nn.conv2d_transpose_nchw(
-                inputs[0], inputs[1], strides, padding, out_dtype)]
-        # If it's not packed, run on ARM CPU
-        with tvm.target.arm_cpu(tvm.target.Target.current().model):
-            return _nn.compute_conv2d_transpose(attrs, inputs, output_type, target)
-
-    # If VTA is not the target, default to _nn def
-    return _nn.compute_conv2d_transpose(attrs, inputs, output_type, target)
 
-
-@reg.register_schedule("nn.conv2d_transpose", level=15)
-def schedule_conv2d_transpose(attrs, outputs, target):
-    """ 2D convolution schedule.
-    """
+    assert dilation == (1, 1), "support for dilation limited to (1, 1)"
+    if is_packed_layout(layout):
+        if groups == 1:
+            env = get_env()
+            assert env.LOG_INP_WIDTH == 3, "only support 8bit inp for now"
+            assert env.LOG_WGT_WIDTH == 3, "only support 8bit wgt for now"
+            assert kernel.dtype == "int8"
+
+            strategy.add_implementation(
+                _strategy.wrap_compute_conv2d(conv2d_packed, True),
+                _strategy.wrap_topi_schedule(schedule_conv2d_packed),
+                name="conv2d_packed.vta")
+        else: # group_conv2d
+            strategy.add_implementation(
+                _strategy.wrap_compute_conv2d(group_conv2d_packed, has_groups=True),
+                _strategy.wrap_topi_schedule(schedule_group_conv2d_packed),
+                name="group_conv2d_packed.vta")
+        return strategy
+
+    # If it's not packed, run on ARM CPU
+    arm_tgt = tvm.target.arm_cpu(target.model)
+    return _strategy.arm_cpu.conv2d_strategy_arm_cpu(attrs, inputs, out_type, arm_tgt)
+
+
+@_strategy.conv2d_transpose_strategy.register("vta")
+def conv2d_transpose_strategy_vta(attrs, inputs, out_type, target):
+    """conv2d_transpose vta strategy"""
+    dilation = topi.util.get_const_tuple(attrs.dilation)
     layout = attrs.data_layout
-
-    if target.device_name == "vta":
-        if is_packed_layout(layout):
-            return topi.nn.schedule_conv2d_transpose_nchw(outputs)
-        # If it's not packed, run on ARM CPU
-        with tvm.target.arm_cpu(tvm.target.Target.current().model):
-            return _nn.schedule_conv2d_transpose(attrs, outputs, tvm.target.Target.current())
-
-    # If VTA is not the target, default to _nn def
-    return _nn.schedule_conv2d_transpose(attrs, outputs, tvm.target.Target.current())
-
-
-@reg.register_compute("nn.dense", level=15)
-def compute_dense(attrs, inputs, out_type, target):
-    """Compute definition of dense"""
-    out_dtype = attrs.out_dtype
-    out_dtype = inputs[0].dtype if out_dtype == "" else out_dtype
-
-    if target.device_name == "vta":
-        if inputs[0].shape == 4: # this implies the layout is packed
-            target = tvm.target.create(target)
-            return [topi.nn.dense(inputs[0], inputs[1], None, out_dtype)]
-        # If it's not packed, run on ARM CPU
-        with tvm.target.arm_cpu(tvm.target.Target.current().model):
-            return _nn.compute_dense(attrs, inputs, out_type, target)
-
-    # If VTA is not the target, default to _nn def
-    return _nn.compute_dense(attrs, inputs, out_type, target)
-
-
-@reg.register_schedule("nn.dense", level=15)
-def schedule_dense(attrs, outs, target):
-    """Schedule definition of dense"""
-    if target.device_name == "vta":
-        if outs[0].shape == 4: # this implies the layout is packed
-            target = tvm.target.create(target)
-            assert target.device_name == "vta"
-            return topi.generic.schedule_dense(outs)
-        # If it's not packed, run on ARM CPU
-        with tvm.target.arm_cpu(tvm.target.Target.current().model):
-            return _nn.schedule_dense(attrs, outs, tvm.target.Target.current())
-
-    # If VTA is not the target, default to _nn def
-    return _nn.schedule_dense(attrs, outs, target)
+    assert dilation == (1, 1), "support for dilation limited to (1, 1)"
+
+    if is_packed_layout(layout):
+        strategy = OpStrategy()
+        strategy.add_implementation(
+            _strategy.wrap_compute_conv2d_transpose(conv2d_transpose_packed),
+            _strategy.wrap_topi_schedule(schedule_conv2d_transpose_packed),
+            name="conv2d_transpose_packed.vta")
+        return strategy
+
+    # If it's not packed, run on ARM CPU
+    arm_tgt = tvm.target.arm_cpu(target.model)
+    return _strategy.arm_cpu.conv2d_transpose_strategy_arm_cpu(attrs, inputs, out_type, arm_tgt)
+
+
+@_strategy.dense_strategy.register("vta")
+def dense_strategy_vta(attrs, inputs, out_type, target):
+    """dense vta strategy"""
+    if inputs[0].shape == 4: # this implies the layout is packed
+        strategy = OpStrategy()
+        strategy.add_implementation(
+            _strategy.wrap_compute_dense(dense_packed),
+            _strategy.wrap_topi_schedule(schedule_dense_packed),
+            name="dense_packed.vta")
+        return strategy
+    # If it's not packed, run on ARM CPU
+    arm_tgt = tvm.target.arm_cpu(target.model)
+    return _strategy.x86.dense_strategy_cpu(attrs, inputs, out_type, arm_tgt)
python/vta/top/vta_conv2d.py

Lines changed: 5 additions & 11 deletions
@@ -25,15 +25,8 @@
 from .util import is_packed_layout
 from ..environment import get_env
 
-@autotvm.register_topi_compute(topi.nn.conv2d, 'vta', 'direct')
-def _declaration_conv2d(cfg,
-                        data,
-                        kernel,
-                        strides,
-                        padding,
-                        dilation,
-                        layout,
-                        out_dtype):
+@autotvm.register_topi_compute("conv2d_packed.vta")
+def conv2d_packed(cfg, data, kernel, strides, padding, dilation, layout, out_dtype):
     """ Packed conv2d function."""
     if not is_packed_layout(layout):
         raise topi.InvalidShapeError()
@@ -69,8 +62,9 @@ def _declaration_conv2d(cfg,
 
     return res
 
-@autotvm.register_topi_schedule(topi.generic.schedule_conv2d_nchw, 'vta', 'direct')
-def _schedule_conv2d(cfg, outs):
+@autotvm.register_topi_schedule("conv2d_packed.vta")
+def schedule_conv2d_packed(cfg, outs):
+    """Schedule packed conv2d"""
     assert len(outs) == 1
     output = outs[0]
     const_ops = []
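The file above also shows the new name-keyed AutoTVM registration: the compute and schedule are paired by a task name of the form "workload.target" instead of a TOPI function plus a 'direct' template key, and both receive the tuning config cfg as their first argument. Below is a minimal sketch of that pattern, not code from this commit, using a hypothetical task name "add_packed.example" and the era-appropriate tvm.create_schedule API:

import tvm
import topi
from tvm import autotvm

@autotvm.register_topi_compute("add_packed.example")
def add_packed(cfg, lhs, rhs):
    """Hypothetical packed elementwise add."""
    return topi.add(lhs, rhs)

@autotvm.register_topi_schedule("add_packed.example")
def schedule_add_packed(cfg, outs):
    """Hypothetical schedule; cfg is where tuning knobs would be defined."""
    s = tvm.create_schedule([x.op for x in outs])
    cfg.define_knob("tile_x", [1, 2, 4])
    return s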

python/vta/top/vta_conv2d_transpose.py

Lines changed: 6 additions & 9 deletions
@@ -26,13 +26,9 @@
 
 from ..environment import get_env
 
-@autotvm.register_topi_compute(topi.nn.conv2d_transpose_nchw, 'vta', 'direct')
-def _declatation_conv2d_transpose(cfg,
-                                  data,
-                                  kernel,
-                                  strides,
-                                  padding,
-                                  out_dtype):
+@autotvm.register_topi_compute("conv2d_transpose_packed.vta")
+def conv2d_transpose_packed(cfg, data, kernel, strides, padding, out_dtype):
+    """Packed conv2d_transpose compute"""
     ishape = get_const_tuple(data.shape)
     kshape = get_const_tuple(kernel.shape)
     b, c_i, i_h, i_w, t_b, t_ci = ishape
@@ -75,8 +71,9 @@ def _declatation_conv2d_transpose(cfg,
 
     return out
 
-@autotvm.register_topi_schedule(topi.generic.schedule_conv2d_transpose_nchw, 'vta', 'direct')
-def _schedule_conv2d_transpose(cfg, outs):
+@autotvm.register_topi_schedule("conv2d_transpose_packed.vta")
+def schedule_conv2d_transpose_packed(cfg, outs):
+    """Schedule packed conv2d_transpose"""
     assert len(outs) == 1
     output = outs[0]
     ewise_inputs = []

python/vta/top/vta_dense.py

Lines changed: 4 additions & 8 deletions
@@ -32,12 +32,8 @@ def is_packed_layout(layout):
         return True
     return False
 
-@autotvm.register_topi_compute(topi.nn.dense, 'vta', 'direct')
-def _declaration_dense(cfg,
-                       data,
-                       weight,
-                       bias=None,
-                       out_dtype=None):
+@autotvm.register_topi_compute("dense_packed.vta")
+def dense_packed(cfg, data, weight, bias=None, out_dtype=None):
     """Dense function declaration."""
 
     # Make sure that the dense operator is packed
@@ -67,8 +63,8 @@ def _declaration_dense(cfg,
 
     return res
 
-@autotvm.register_topi_schedule(topi.generic.schedule_dense, 'vta', 'direct')
-def _schedule_dense(cfg, outs):
+@autotvm.register_topi_schedule("dense_packed.vta")
+def schedule_dense_packed(cfg, outs):
     """Packed dense schedule."""
 
     assert len(outs) == 1

python/vta/top/vta_group_conv2d.py

Lines changed: 4 additions & 4 deletions
@@ -24,8 +24,8 @@
 
 from ..environment import get_env
 
-@autotvm.register_topi_compute(topi.nn.group_conv2d_nchw, 'vta', 'direct')
-def packed_group_conv2d(cfg,
+@autotvm.register_topi_compute("group_conv2d_packed.vta")
+def group_conv2d_packed(cfg,
                         data,
                         kernel,
                         strides,
@@ -74,8 +74,8 @@ def packed_group_conv2d(cfg,
     return out
 
 
-@autotvm.register_topi_schedule(topi.generic.schedule_group_conv2d_nchw, 'vta', 'direct')
-def schedule_packed_group_conv2d(cfg, outs):
+@autotvm.register_topi_schedule("group_conv2d_packed.vta")
+def schedule_group_conv2d_packed(cfg, outs):
     """ Schedule the packed conv2d.
     """
     assert len(outs) == 1
