Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 10 additions & 65 deletions python/tvm/autotvm/task/relay_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,7 @@ def _build(func,
def extract_from_program(func, params, ops, target, target_host=None):
""" Extract tuning tasks from a relay program.

This function collects tuning tasks by building the program
with a "tracing" target and tracing all the calls to topi.
This function is the single program version of extract_from_multiple_program.

Parameters
----------
Expand All @@ -73,66 +72,14 @@ def extract_from_program(func, params, ops, target, target_host=None):
task: Array of autotvm.task.Task
collected tasks
"""
import tvm.relay.op
from tvm import relay
import topi

env = TaskExtractEnv.get()

# NOTE: To add more ops, you only need to change the following lists
# relay op -> topi compute
OP2TOPI = {
tvm.relay.op.nn.conv2d: [topi.nn.conv2d, topi.nn.depthwise_conv2d_nchw,
topi.nn.group_conv2d_nchw, topi.nn.conv2d_NCHWc],
tvm.relay.op.nn.conv2d_transpose: [topi.nn.conv2d_transpose_nchw],
tvm.relay.op.nn.dense: [topi.nn.dense],
tvm.relay.op.nn.deformable_conv2d: [topi.nn.deformable_conv2d_nchw],
}

topi_funcs = []
for op_name in ops:
if op_name in OP2TOPI:
topi_funcs.extend(OP2TOPI[op_name])
else:
warnings.warn("Op %s is not tunable, ignored" % op_name)

# run compiler to collect all TOPI calls during compilation
env.reset(topi_funcs)
with env:
# disable logger temporarily
old_state = logger.disabled
logger.disabled = True

relay.backend.compile_engine.get().clear()
# wrap build call in thread to avoid multiprocessing problems
mod = relay.Module.from_expr(func)
build_thread = threading.Thread(target=_build,
args=(mod,
target,
target_host,
params))
build_thread.start()
build_thread.join()

logger.disabled = old_state

# create tasks for target
tasks = []
for task_name, args in env.get_tasks():
try:
tsk = create(task_name, args,
target=target, target_host=target_host,
template_key='direct')
tasks.append(tsk)
except topi.InvalidShapeError:
warnings.warn("Invalid shape during AutoTVM task creation")
return tasks
return extract_from_multiple_program([func], [params], ops, target, target_host)


def extract_from_multiple_program(funcs, params, ops, target, target_host=None):
""" Extract tuning tasks from multiple relay programs.

This function is the multiple program version of extract_from_program
This function collects tuning tasks by building a list of programs
with a "tracing" target and tracing all the calls to topi.

Parameters
----------
Expand All @@ -152,19 +99,20 @@ def extract_from_multiple_program(funcs, params, ops, target, target_host=None):
task: Array of autotvm.task.Task
collected tasks
"""
env = TaskExtractEnv.get()
import tvm.relay.op
from tvm import relay
import topi

env = TaskExtractEnv.get()

# NOTE: To add more ops, you only need to change the following lists
# relay op -> topi compute
OP2TOPI = {
tvm.relay.op.nn.conv2d: [topi.nn.conv2d, topi.nn.depthwise_conv2d_nchw,
topi.nn.group_conv2d_nchw],
topi.nn.group_conv2d_nchw, topi.nn.conv2d_NCHWc],
tvm.relay.op.nn.conv2d_transpose: [topi.nn.conv2d_transpose_nchw],
tvm.relay.op.nn.dense: [topi.nn.dense],
tvm.relay.op.nn.contrib_deformable_conv2d: [topi.nn.deformable_conv2d_nchw],
tvm.relay.op.nn.deformable_conv2d: [topi.nn.deformable_conv2d_nchw],
}

topi_funcs = []
Expand All @@ -185,11 +133,8 @@ def extract_from_multiple_program(funcs, params, ops, target, target_host=None):
relay.backend.compile_engine.get().clear()
# wrap build call in thread to avoid multiprocessing problems
mod = relay.Module.from_expr(func)
build_thread = threading.Thread(target=my_build,
args=(mod,
target,
target_host,
params))
build_thread = threading.Thread(target=_build,
args=(mod, target, target_host, param))
build_thread.start()
build_thread.join()

Expand Down
21 changes: 16 additions & 5 deletions tests/python/relay/test_autotvm_task_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,36 +37,47 @@ def get_network(name, batch_size):

def test_task_extraction():
target = 'llvm'
mod_list = []
params_list = []

mod, params, input_shape = get_network('resnet-18', batch_size=1)
mod, params, _ = get_network('resnet-18', batch_size=1)
tasks = autotvm.task.extract_from_program(mod["main"], target=target,
params=params,
ops=(relay.op.nn.conv2d,))
assert len(tasks) == 12

mod, params, input_shape = get_network('resnet-18', batch_size=1)
mod, params, _ = get_network('resnet-18', batch_size=1)
tasks = autotvm.task.extract_from_program(mod["main"], target=target,
params=params,
ops=(relay.op.nn.dense,))
assert len(tasks) == 1

mod, params, input_shape = get_network('resnet-18', batch_size=1)
mod, params, _ = get_network('resnet-18', batch_size=1)
mod_list.append(mod)
params_list.append(params)
tasks = autotvm.task.extract_from_program(mod["main"], target=target,
params=params,
ops=(relay.op.nn.conv2d, relay.op.nn.dense))
assert len(tasks) == 13

mod, params, input_shape = get_network('mobilenet', batch_size=1)
mod, params, _ = get_network('mobilenet', batch_size=1)
mod_list.append(mod)
params_list.append(params)
tasks = autotvm.task.extract_from_program(mod["main"], target=target,
params=params,
ops=(relay.op.nn.conv2d, relay.op.nn.dense))
assert len(tasks) == 20

mod, params, input_shape = get_network('dcgan', batch_size=1)
mod, params, _ = get_network('dcgan', batch_size=1)
tasks = autotvm.task.extract_from_program(mod["main"], target=target,
params=params,
ops=(relay.op.nn.conv2d_transpose,))
assert len(tasks) == 4

tasks = autotvm.task.extract_from_multiple_program([m['main'] for m in mod_list], params_list,
target=target,
ops=(relay.op.nn.conv2d,))
assert len(tasks) == 31

if __name__ == '__main__':
test_task_extraction()