
Commit 6b4cb34

fix conv2d and conv2d alter op layout for x86
1 parent: 69b171c

18 files changed (+290, -141 lines)

include/tvm/relay/op_attr_types.h

Lines changed: 2 additions & 1 deletion
@@ -140,7 +140,8 @@ using FTVMStrategy = GenericFunc;
 using FTVMAlterOpLayout = runtime::TypedPackedFunc<
   Expr(const Attrs& attrs,
        const Array<Expr>& args,
-       const Array<Tensor>& tinfos)>;
+       const Array<Tensor>& tinfos,
+       const Type& out_type)>;
 
 /*!
  * \brief Convert the layout of operators or replace the
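For illustration, a Python-side callback matching the updated FTVMAlterOpLayout signature would now receive the inferred output type as a fourth argument. A minimal sketch (the function name and body are hypothetical; only the parameter list mirrors the header change above):

def example_alter_op_layout(attrs, inputs, tinfos, out_type):
    # attrs: call attributes, inputs: Relay argument expressions,
    # tinfos: input tensor placeholders, out_type: inferred output type.
    # Returning None keeps the original operator unchanged.
    return None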

python/tvm/autotvm/graph_tuner/utils/traverse_graph.py

Lines changed: 3 additions & 0 deletions
@@ -65,6 +65,9 @@ def expr2graph(expr, target_ops, node_dict, node_list):
                                % op_name)
         topi_funcs += OP2COMPUTE[op_name]
     env.reset(topi_funcs)
+    # TODO(@kevinthesun, @icemelon9): Currently graph tuning pass relies on the fact
+    # that # autotvm tasks == # ops. But this won't be true after having relay op
+    # strategy. We need to find a solution to fix this.
     with env:
         _expr2graph_impl(expr, target_ops, node_dict, node_list)
     task_pos = 0

python/tvm/autotvm/task/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -30,5 +30,5 @@
     FallbackContext, clear_fallback_cache, ApplyGraphBest
 
 from .topi_integration import register_topi_compute, register_topi_schedule, \
-    TaskExtractEnv, register_topi_compute2, register_topi_schedule2
+    TaskExtractEnv, register_topi_compute2, register_topi_schedule2, get_workload
 from .relay_integration import extract_from_program, extract_from_multiple_program

python/tvm/autotvm/task/dispatcher.py

Lines changed: 5 additions & 6 deletions
@@ -483,13 +483,14 @@ def _query_inside(self, target, workload):
         cfg : ConfigSpace
             The specific configuration.
         """
-        print('=' * 80)
-        print('query graph dispatcher: %s, %s' % (target, workload))
         if self._counter < len(self._records):
             cfg = self._records[self._counter][0].config
+            wkl = self._records[self._counter][0].task.workload
+            if workload is not None:
+                assert wkl == workload
             self._counter += 1
-            print(self._counter, cfg)
-            self.update(target, workload, cfg)
+            self.update(target, wkl, cfg)
+            cfg.workload = wkl
             return cfg
         key = (str(target), workload)
         if key not in self._global_cfg_dict:
@@ -504,7 +505,5 @@ def _query_inside(self, target, workload):
         return cfg
 
     def update(self, target, workload, cfg):
-        print('-' * 80)
-        print('update %s %s -> %s' % (target, workload, cfg))
         key = (str(target), workload)
         self._global_cfg_dict[key] = cfg
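For context, ApplyGraphBest replays its tuning records in the order operators are visited during compilation; the change above additionally checks that each replayed record matches the queried workload (when one is given) and attaches that workload to the returned config. A hedged usage sketch of the dispatcher (the log file name, module, and parameters are illustrative):

from tvm import autotvm, relay

# mod and params are assumed to be an existing Relay module and its parameters;
# "graph_opt.log" stands in for a graph-tuner schedule log.
with autotvm.apply_graph_best("graph_opt.log"):
    graph, lib, params = relay.build(mod, target="llvm", params=params)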

python/tvm/autotvm/task/relay_integration.py

Lines changed: 1 addition & 1 deletion
@@ -171,7 +171,7 @@ def extract_from_multiple_program(mods, params, ops, target, target_host=None,
 
     # create tasks for target
     tasks = []
-    for task_name, args, _ in env.get_tasks():
+    for task_name, args in env.get_tasks():
         try:
             key = task_name_to_keys[task_name] if task_name in task_name_to_keys else 'direct'
             tsk = create(task_name, args,

python/tvm/autotvm/task/topi_integration.py

Lines changed: 21 additions & 20 deletions
@@ -287,7 +287,7 @@ def reset(self, wanted_topi_funcs):
         if wanted_topi_funcs is not None:
             self.wanted_topi_funcs = wanted_topi_funcs
 
-    def add_task(self, task_name, args, cond=None):
+    def add_task(self, task_name, args):
         """Add AutoTVM task
 
         Parameters
@@ -301,9 +301,7 @@ def add_task(self, task_name, args, cond=None):
         cond: SpecializedCondition
             Specialized condition to enable the TOPI template.
         """
-        assert cond is None, \
-            "AutoTVM currently doesn't support tuning under specialized condition"
-        key = (task_name, serialize_args(args), None)
+        key = (task_name, serialize_args(args))
         if self.allow_duplicate or key not in self.task_collection:
             self.task_collection.append(key)
 
@@ -515,7 +513,7 @@ def wrapper(*args, **kwargs):
             assert not kwargs, "Do not support kwargs in template function call"
             task_env = TaskExtractEnv.current
             if task_env is not None and task_env.tracing:
-                task_env.add_task(task_name, args, current_specialization())
+                task_env.add_task(task_name, args)
             workload = args_to_workload2(args, task_name)
             tgt = _target.current_target()
             cfg = DispatchContext.current.query(tgt, workload)
@@ -548,31 +546,34 @@ def wrapper(*args, **kwargs):
         return _decorate(func)
     return _decorate
 
+
 def register_topi_schedule2(task_name, func=None):
     def _decorate(topi_schedule):
         @register_task_schedule(task_name)
         def wrapper(outs, *args, **kwargs):
-            def traverse(tensors):
-                """traverse all ops to find attached workload"""
-                for t in tensors:
-                    op = t.op
-                    if 'workload' in op.attrs:
-                        return op.attrs['workload']
-                    wkl = traverse(op.input_tensors)
-                    if wkl:
-                        return wkl
-                return None
-
-            outs = [outs] if isinstance(outs, tensor.Tensor) else outs
-            workload = traverse(outs)
+            workload = get_workload(outs)
             if workload is None:
                 raise RuntimeError("Cannot find workload in attribute of this schedule")
             workload = args_to_workload2(workload)
             tgt = _target.current_target()
             cfg = DispatchContext.current.query(tgt, workload)
             return topi_schedule(cfg, outs, *args, **kwargs)
-
         return wrapper
     if func:
         return _decorate(func)
-    return _decorate
+    return _decorate
+
+
+def get_workload(outs):
+    def traverse(tensors):
+        """traverse all ops to find attached workload"""
+        for t in tensors:
+            op = t.op
+            if 'workload' in op.attrs:
+                return op.attrs['workload']
+            wkl = traverse(op.input_tensors)
+            if wkl:
+                return wkl
+        return None
+    outs = [outs] if isinstance(outs, tensor.Tensor) else outs
+    return traverse(outs)
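Because the traversal now lives in a module-level get_workload helper (re-exported from tvm.autotvm.task, per the __init__.py change above), callers outside this file can recover the workload that a register_topi_compute2-decorated compute attaches to its output op. A hedged sketch, assuming outs are the output tensors of such a compute and target is the target being compiled for:

from tvm import autotvm

# Walk back from the outputs to the op carrying the 'workload' attribute.
workload = autotvm.task.get_workload(outs)
if workload is not None:
    workload = autotvm.task.args_to_workload2(workload)
    cfg = autotvm.task.DispatchContext.current.query(target, workload)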

python/tvm/relay/backend/compile_engine.py

Lines changed: 62 additions & 49 deletions
@@ -25,7 +25,7 @@
 from ... import _api_internal
 from ... import target as _target
 from ..._ffi.function import register_func
-from ...autotvm import task as _task
+from ... import autotvm
 from .. import expr as _expr
 from .. import op as _op
 from .. import ty as _ty
@@ -97,6 +97,60 @@ def get_shape(shape):
     return ret
 
 
+def get_valid_implements(op, attrs, inputs, out_type, target):
+    """only use this function with concrete shapes"""
+    fstrategy = op.get_attr("FTVMStrategy")
+    assert fstrategy is not None, "%s doesn't have FTVMStrategy registered" % op.name
+    with target:
+        strategy = fstrategy(attrs, inputs, out_type, target)
+    ret = []
+    for spec in strategy.specializations:
+        if spec.condition:
+            for clause in spec.condition.clauses:
+                clause = tvm.ir_pass.Simplify(clause)
+                if isinstance(clause, tvm.expr.IntImm) and int(clause):
+                    ret.append(impl)
+        else:
+            for impl in spec.implements:
+                ret.append(impl)
+    return ret
+
+
+def select_implement(op, attrs, inputs, out_type, target, use_autotvm=True):
+    """only use this function with concrete shapes"""
+    all_impls = get_valid_implements(op, attrs, inputs, out_type, target)
+
+    best_plevel_impl = None
+    for impl in all_impls:
+        if best_plevel_impl is None or int(impl.plevel) > int(best_plevel_impl.plevel):
+            best_plevel_impl = impl
+    if not use_autotvm:
+        outs = best_plevel_impl.compute(attrs, inputs, out_type)
+        return best_plevel_impl, outs
+
+    outputs = {}
+    best_autotvm_impl = None
+    best_cfg = None
+    dispatch_ctx = autotvm.task.DispatchContext.current
+    for impl in all_impls:
+        outs = impl.compute(attrs, inputs, out_type)
+        outputs[impl] = outs
+        workload = autotvm.task.get_workload(outs)
+        if workload is None:
+            continue
+        workload = autotvm.task.args_to_workload2(workload)
+        cfg = dispatch_ctx.query(target, workload)
+        if cfg.cost is None:
+            # It's a fallback config
+            continue
+        if best_cfg is None or best_cfg.cost > cfg.cost:
+            best_autotvm_impl = impl
+            best_cfg = cfg
+    if best_autotvm_impl:
+        return best_autotvm_impl, outputs[best_autotvm_impl]
+    return best_plevel_impl, outputs[best_plevel_impl]
+
+
 class ScheduleGetter(ExprVisitor):
     """Get the schedule given a fused Relay function"""
 
@@ -199,35 +253,18 @@ def visit_call(self, call):
             outputs = [_api_internal._Tensor(copy_input.shape, copy_input.dtype,
                                              None, 0)]
         else:
-            fstrategy = op.get_attr("FTVMStrategy")
-            assert fstrategy is not None, "%s doesn't have FTVMStrategy registered" % op.name
-            strategy = fstrategy(call.attrs, inputs, ret_type, self.target)
-            op_spec = None
-            # TODO(@icemelon9): current only use the default specialization (with no condition)
-            for spec in strategy.specializations:
-                if spec.condition is None:
-                    op_spec = spec
-                    break
-            assert op_spec is not None, \
-                "Cannot find default specialization for op %s" % op.name
-            assert len(op_spec.implements) > 0
-
             is_dyn = call.checked_type.is_dynamic()
             for arg in call.args:
                 is_dyn = is_dyn or arg.checked_type.is_dynamic()
 
             if not is_dyn:
-                best_imp = self.get_best_implement_by_autotvm(
-                    op_spec, call.attrs, inputs, ret_type)
-                if best_imp is None:
-                    best_imp = self.get_best_implement_by_plevel(
-                        op_spec, call.attrs, inputs, ret_type)
+                best_impl, outputs = select_implement(
+                    op, call.attrs, inputs, ret_type, self.target)
             else:
-                # for dynamic case, we just use the implementation with highest score
-                best_imp = self.get_best_implement_by_plevel(
-                    op_spec, call.attrs, inputs, ret_type)
-            assert best_imp is not None
-            outputs = best_imp.compute(call.attrs, inputs, ret_type)
+                # TODO(@icemelon9): Allow tvm to generate multiple kernels for dynamic shapes
+                # for dynamic case, we currently use the implementation with highest plevel
+                best_impl, outputs = select_implement(
+                    op, call.attrs, inputs, ret_type, self.target, use_autotvm=False)
         op_pattern = op.get_attr("TOpPattern")
         if op_pattern >= _op.OpPattern.COMM_REDUCE:
             assert self.master_op is None or self.master_op_pattern < _op.OpPattern.COMM_REDUCE, \
@@ -237,7 +274,7 @@ def visit_call(self, call):
             self.master_op = op
             self.master_attrs = call.attrs
             self.master_op_pattern = op_pattern
-            self.master_implement = best_imp
+            self.master_implement = best_impl
         if len(outputs) > 1:
             assert isinstance(call.checked_type, _ty.TupleType)
             assert len(call.checked_type.fields) == len(outputs)
@@ -269,30 +306,6 @@ def visit_tuple_getitem(self, op):
         assert op.index < tup.size()
         return [tup[op.index]]
 
-    def get_best_implement_by_autotvm(self, op_spec, attrs, inputs, ret_type):
-        min_cost = None
-        best_imp = None
-        for imp in op_spec.implements:
-            outs = imp.compute(attrs, inputs, ret_type)
-            if 'workload' not in outs[0].op.attrs:
-                continue
-            workload = _task.args_to_workload2(outs[0].op.attrs['workload'])
-            cfg = _task.DispatchContext.current.query(self.target, workload)
-            if cfg.cost is None:
-                # This is fallback config
-                continue
-            if min_cost is None or min_cost > cfg.cost:
-                min_cost = cfg.cost
-                best_imp = imp
-        return best_imp
-
-    def get_best_implement_by_plevel(self, op_spec, attrs, inputs, ret_type):
-        best_imp = None
-        for imp in op_spec.implements:
-            if best_imp is None or int(imp.plevel) > int(best_imp.plevel):
-                best_imp = imp
-        return best_imp
-
 
 @register_func("relay.backend.create_schedule")
 def create_schedule(src_func, target):
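The new select_implement helper first picks the implementation with the highest plevel and, when AutoTVM is enabled, prefers any implementation whose tuned config carries a measured cost. A hedged sketch of invoking it directly (normally the compile engine calls it while visiting a call node; attrs, inputs, and out_type are placeholders that would come from a type-checked call and its lowered tensor inputs):

import tvm
from tvm import relay
from tvm.relay.backend.compile_engine import select_implement

op = relay.op.get("nn.conv2d")
target = tvm.target.create("llvm")
with target:
    # use_autotvm=False selects purely by plevel, matching the dynamic-shape path.
    best_impl, outputs = select_implement(op, attrs, inputs, out_type, target,
                                          use_autotvm=False)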

python/tvm/relay/op/nn/_nn.py

Lines changed: 2 additions & 3 deletions
@@ -183,10 +183,9 @@ def _find_conv2d_op(op):
 reg.register_strategy("nn.conv2d", strategy.conv2d_strategy)
 
 @reg.register_alter_op_layout("nn.conv2d")
-def alter_op_layout_conv2d(attrs, inputs, tinfos):
+def alter_op_layout_conv2d(attrs, inputs, tinfos, out_type):
     """Alternate the layout of conv2d"""
-    from ... import op
-    return topi.nn.conv2d_alter_layout(attrs, inputs, tinfos, op)
+    return topi.nn.conv2d_alter_layout(attrs, inputs, tinfos, out_type)
 
 @reg.register_legalize("nn.conv2d")
 def legalize_conv2d(attrs, inputs, types):
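With out_type now forwarded into topi.nn.conv2d_alter_layout, target-specific overrides of that generic function are expected to accept the same four arguments. A hedged sketch of an x86-style override (the decorator pattern follows existing TOPI registrations; the body is illustrative only):

import topi

@topi.nn.conv2d_alter_layout.register("cpu")
def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
    # A real implementation would inspect attrs and tinfos, pick e.g. an NCHWc
    # layout, and return a rewritten Relay expression; None keeps the op as-is.
    return None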

python/tvm/relay/op/op.py

Lines changed: 0 additions & 10 deletions
@@ -347,16 +347,6 @@ def _build(lowered_funcs):
 _schedule_injective = None
 _schedule_reduce = None
 
-# def register_injective_schedule(target, schedule):
-#     def wrap_schedule(_, outs):
-#         return schedule(outs)
-#     _injective_schedule_map.append([target, wrap_schedule])
-#
-# def register_reduce_schedule(target, schedule):
-#     def wrap_schedule(_, outs):
-#         return schedule(outs)
-#     _reduce_schedule_map.append([target, wrap_schedule])
-
 __DEBUG_COUNTER__ = 0
 
 def debug(expr, debug_func=None):
