Skip to content

Commit 5f7f300

Browse files
comaniackevinthesun
authored andcommitted
merge extract_from_program and extract_from_multiple_progam (apache#4173)
1 parent 7e29f18 commit 5f7f300

File tree

2 files changed

+26
-70
lines changed

2 files changed

+26
-70
lines changed

python/tvm/autotvm/task/relay_integration.py

Lines changed: 10 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,7 @@ def _build(func,
5252
def extract_from_program(func, params, ops, target, target_host=None):
5353
""" Extract tuning tasks from a relay program.
5454
55-
This function collects tuning tasks by building the program
56-
with a "tracing" target and tracing all the calls to topi.
55+
This function is the single program version of extract_from_multiple_program.
5756
5857
Parameters
5958
----------
@@ -73,66 +72,14 @@ def extract_from_program(func, params, ops, target, target_host=None):
7372
task: Array of autotvm.task.Task
7473
collected tasks
7574
"""
76-
import tvm.relay.op
77-
from tvm import relay
78-
import topi
79-
80-
env = TaskExtractEnv.get()
81-
82-
# NOTE: To add more ops, you only need to change the following lists
83-
# relay op -> topi compute
84-
OP2TOPI = {
85-
tvm.relay.op.nn.conv2d: [topi.nn.conv2d, topi.nn.depthwise_conv2d_nchw,
86-
topi.nn.group_conv2d_nchw, topi.nn.conv2d_NCHWc],
87-
tvm.relay.op.nn.conv2d_transpose: [topi.nn.conv2d_transpose_nchw],
88-
tvm.relay.op.nn.dense: [topi.nn.dense],
89-
tvm.relay.op.nn.deformable_conv2d: [topi.nn.deformable_conv2d_nchw],
90-
}
91-
92-
topi_funcs = []
93-
for op_name in ops:
94-
if op_name in OP2TOPI:
95-
topi_funcs.extend(OP2TOPI[op_name])
96-
else:
97-
warnings.warn("Op %s is not tunable, ignored" % op_name)
98-
99-
# run compiler to collect all TOPI calls during compilation
100-
env.reset(topi_funcs)
101-
with env:
102-
# disable logger temporarily
103-
old_state = logger.disabled
104-
logger.disabled = True
105-
106-
relay.backend.compile_engine.get().clear()
107-
# wrap build call in thread to avoid multiprocessing problems
108-
mod = relay.Module.from_expr(func)
109-
build_thread = threading.Thread(target=_build,
110-
args=(mod,
111-
target,
112-
target_host,
113-
params))
114-
build_thread.start()
115-
build_thread.join()
116-
117-
logger.disabled = old_state
118-
119-
# create tasks for target
120-
tasks = []
121-
for task_name, args in env.get_tasks():
122-
try:
123-
tsk = create(task_name, args,
124-
target=target, target_host=target_host,
125-
template_key='direct')
126-
tasks.append(tsk)
127-
except topi.InvalidShapeError:
128-
warnings.warn("Invalid shape during AutoTVM task creation")
129-
return tasks
75+
return extract_from_multiple_program([func], [params], ops, target, target_host)
13076

13177

13278
def extract_from_multiple_program(funcs, params, ops, target, target_host=None):
13379
""" Extract tuning tasks from multiple relay programs.
13480
135-
This function is the multiple program version of extract_from_program
81+
This function collects tuning tasks by building a list of programs
82+
with a "tracing" target and tracing all the calls to topi.
13683
13784
Parameters
13885
----------
@@ -152,19 +99,20 @@ def extract_from_multiple_program(funcs, params, ops, target, target_host=None):
15299
task: Array of autotvm.task.Task
153100
collected tasks
154101
"""
155-
env = TaskExtractEnv.get()
156102
import tvm.relay.op
157103
from tvm import relay
158104
import topi
159105

106+
env = TaskExtractEnv.get()
107+
160108
# NOTE: To add more ops, you only need to change the following lists
161109
# relay op -> topi compute
162110
OP2TOPI = {
163111
tvm.relay.op.nn.conv2d: [topi.nn.conv2d, topi.nn.depthwise_conv2d_nchw,
164-
topi.nn.group_conv2d_nchw],
112+
topi.nn.group_conv2d_nchw, topi.nn.conv2d_NCHWc],
165113
tvm.relay.op.nn.conv2d_transpose: [topi.nn.conv2d_transpose_nchw],
166114
tvm.relay.op.nn.dense: [topi.nn.dense],
167-
tvm.relay.op.nn.contrib_deformable_conv2d: [topi.nn.deformable_conv2d_nchw],
115+
tvm.relay.op.nn.deformable_conv2d: [topi.nn.deformable_conv2d_nchw],
168116
}
169117

170118
topi_funcs = []
@@ -185,11 +133,8 @@ def extract_from_multiple_program(funcs, params, ops, target, target_host=None):
185133
relay.backend.compile_engine.get().clear()
186134
# wrap build call in thread to avoid multiprocessing problems
187135
mod = relay.Module.from_expr(func)
188-
build_thread = threading.Thread(target=my_build,
189-
args=(mod,
190-
target,
191-
target_host,
192-
params))
136+
build_thread = threading.Thread(target=_build,
137+
args=(mod, target, target_host, param))
193138
build_thread.start()
194139
build_thread.join()
195140

tests/python/relay/test_autotvm_task_extraction.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -37,36 +37,47 @@ def get_network(name, batch_size):
3737

3838
def test_task_extraction():
3939
target = 'llvm'
40+
mod_list = []
41+
params_list = []
4042

41-
mod, params, input_shape = get_network('resnet-18', batch_size=1)
43+
mod, params, _ = get_network('resnet-18', batch_size=1)
4244
tasks = autotvm.task.extract_from_program(mod["main"], target=target,
4345
params=params,
4446
ops=(relay.op.nn.conv2d,))
4547
assert len(tasks) == 12
4648

47-
mod, params, input_shape = get_network('resnet-18', batch_size=1)
49+
mod, params, _ = get_network('resnet-18', batch_size=1)
4850
tasks = autotvm.task.extract_from_program(mod["main"], target=target,
4951
params=params,
5052
ops=(relay.op.nn.dense,))
5153
assert len(tasks) == 1
5254

53-
mod, params, input_shape = get_network('resnet-18', batch_size=1)
55+
mod, params, _ = get_network('resnet-18', batch_size=1)
56+
mod_list.append(mod)
57+
params_list.append(params)
5458
tasks = autotvm.task.extract_from_program(mod["main"], target=target,
5559
params=params,
5660
ops=(relay.op.nn.conv2d, relay.op.nn.dense))
5761
assert len(tasks) == 13
5862

59-
mod, params, input_shape = get_network('mobilenet', batch_size=1)
63+
mod, params, _ = get_network('mobilenet', batch_size=1)
64+
mod_list.append(mod)
65+
params_list.append(params)
6066
tasks = autotvm.task.extract_from_program(mod["main"], target=target,
6167
params=params,
6268
ops=(relay.op.nn.conv2d, relay.op.nn.dense))
6369
assert len(tasks) == 20
6470

65-
mod, params, input_shape = get_network('dcgan', batch_size=1)
71+
mod, params, _ = get_network('dcgan', batch_size=1)
6672
tasks = autotvm.task.extract_from_program(mod["main"], target=target,
6773
params=params,
6874
ops=(relay.op.nn.conv2d_transpose,))
6975
assert len(tasks) == 4
7076

77+
tasks = autotvm.task.extract_from_multiple_program([m['main'] for m in mod_list], params_list,
78+
target=target,
79+
ops=(relay.op.nn.conv2d,))
80+
assert len(tasks) == 31
81+
7182
if __name__ == '__main__':
7283
test_task_extraction()

0 commit comments

Comments
 (0)