from ... import _api_internal
from ... import target as _target
from ..._ffi.function import register_func
-from ...autotvm import task as _task
+from ... import autotvm
from .. import expr as _expr
from .. import op as _op
from .. import ty as _ty
@@ -97,6 +97,60 @@ def get_shape(shape):
    return ret


+def get_valid_implements(op, attrs, inputs, out_type, target):
+    """Get all valid implementations of the op from its strategy.
+    Note that this function should only be used with concrete shapes.
+    """
+    fstrategy = op.get_attr("FTVMStrategy")
+    assert fstrategy is not None, "%s doesn't have FTVMStrategy registered" % op.name
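+    # the strategy function returns an op strategy holding one or more
+    # specializations, each carrying a (possibly conditional) list of implementations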
+    with target:
+        strategy = fstrategy(attrs, inputs, out_type, target)
+    ret = []
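+    # keep implementations from the generic specialization, plus those whose
+    # specialized condition simplifies to a constant true for the given shapes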
+    for spec in strategy.specializations:
+        if spec.condition:
+            flag = True
+            for clause in spec.condition.clauses:
+                clause = tvm.ir_pass.Simplify(clause)
+                if not (isinstance(clause, tvm.expr.IntImm) and int(clause)):
+                    flag = False
+                    break
+            if flag:
+                for impl in spec.implements:
+                    ret.append(impl)
+        else:
+            for impl in spec.implements:
+                ret.append(impl)
+    return ret
+
+
+def select_implement(op, attrs, inputs, out_type, target, use_autotvm=True):
+    """Select the best implementation for the op from its strategy.
+    Note that this function should only be used with concrete shapes.
+    """
+    all_impls = get_valid_implements(op, attrs, inputs, out_type, target)
+
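+    # the implementation with the highest priority level (plevel) is the default
+    # choice; it is also the fallback when no tuned AutoTVM config is found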
+    best_plevel_impl = None
+    for impl in all_impls:
+        if best_plevel_impl is None or int(impl.plevel) > int(best_plevel_impl.plevel):
+            best_plevel_impl = impl
+    if not use_autotvm:
+        outs = best_plevel_impl.compute(attrs, inputs, out_type)
+        return best_plevel_impl, outs
+
+    outputs = {}
+    best_autotvm_impl = None
+    best_cfg = None
+    dispatch_ctx = autotvm.task.DispatchContext.current
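+    # compute each implementation and query the AutoTVM dispatch context with its
+    # workload; pick the implementation whose tuned config has the lowest cost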
+    for impl in all_impls:
+        outs = impl.compute(attrs, inputs, out_type)
+        outputs[impl] = outs
+        workload = autotvm.task.get_workload(outs)
+        if workload is None:
+            continue
+        workload = autotvm.task.args_to_workload2(workload)
+        cfg = dispatch_ctx.query(target, workload)
+        if cfg.cost is None:
+            # It's a fallback config
+            continue
+        if best_cfg is None or best_cfg.cost > cfg.cost:
+            best_autotvm_impl = impl
+            best_cfg = cfg
+    if best_autotvm_impl:
+        return best_autotvm_impl, outputs[best_autotvm_impl]
+    return best_plevel_impl, outputs[best_plevel_impl]
+
+
class ScheduleGetter(ExprVisitor):
    """Get the schedule given a fused Relay function"""

@@ -199,35 +253,18 @@ def visit_call(self, call):
            outputs = [_api_internal._Tensor(copy_input.shape, copy_input.dtype,
                                             None, 0)]
        else:
-            fstrategy = op.get_attr("FTVMStrategy")
-            assert fstrategy is not None, "%s doesn't have FTVMStrategy registered" % op.name
-            strategy = fstrategy(call.attrs, inputs, ret_type, self.target)
-            op_spec = None
-            # TODO(@icemelon9): current only use the default specialization (with no condition)
-            for spec in strategy.specializations:
-                if spec.condition is None:
-                    op_spec = spec
-                    break
-            assert op_spec is not None, \
-                "Cannot find default specialization for op %s" % op.name
-            assert len(op_spec.implements) > 0
-
            is_dyn = call.checked_type.is_dynamic()
            for arg in call.args:
                is_dyn = is_dyn or arg.checked_type.is_dynamic()

            if not is_dyn:
-                best_imp = self.get_best_implement_by_autotvm(
-                    op_spec, call.attrs, inputs, ret_type)
-                if best_imp is None:
-                    best_imp = self.get_best_implement_by_plevel(
-                        op_spec, call.attrs, inputs, ret_type)
+                best_impl, outputs = select_implement(
+                    op, call.attrs, inputs, ret_type, self.target)
            else:
-                # for dynamic case, we just use the implementation with highest score
-                best_imp = self.get_best_implement_by_plevel(
-                    op_spec, call.attrs, inputs, ret_type)
-            assert best_imp is not None
-            outputs = best_imp.compute(call.attrs, inputs, ret_type)
+                # TODO(@icemelon9): Allow tvm to generate multiple kernels for dynamic shapes
+                # For the dynamic case, we currently use the implementation with the highest plevel
+                best_impl, outputs = select_implement(
+                    op, call.attrs, inputs, ret_type, self.target, use_autotvm=False)
            op_pattern = op.get_attr("TOpPattern")
            if op_pattern >= _op.OpPattern.COMM_REDUCE:
                assert self.master_op is None or self.master_op_pattern < _op.OpPattern.COMM_REDUCE, \
@@ -237,7 +274,7 @@ def visit_call(self, call):
                self.master_op = op
                self.master_attrs = call.attrs
                self.master_op_pattern = op_pattern
-                self.master_implement = best_imp
+                self.master_implement = best_impl
        if len(outputs) > 1:
            assert isinstance(call.checked_type, _ty.TupleType)
            assert len(call.checked_type.fields) == len(outputs)
@@ -269,30 +306,6 @@ def visit_tuple_getitem(self, op):
        assert op.index < tup.size()
        return [tup[op.index]]

-    def get_best_implement_by_autotvm(self, op_spec, attrs, inputs, ret_type):
-        min_cost = None
-        best_imp = None
-        for imp in op_spec.implements:
-            outs = imp.compute(attrs, inputs, ret_type)
-            if 'workload' not in outs[0].op.attrs:
-                continue
-            workload = _task.args_to_workload2(outs[0].op.attrs['workload'])
-            cfg = _task.DispatchContext.current.query(self.target, workload)
-            if cfg.cost is None:
-                # This is fallback config
-                continue
-            if min_cost is None or min_cost > cfg.cost:
-                min_cost = cfg.cost
-                best_imp = imp
-        return best_imp
-
-    def get_best_implement_by_plevel(self, op_spec, attrs, inputs, ret_type):
-        best_imp = None
-        for imp in op_spec.implements:
-            if best_imp is None or int(imp.plevel) > int(best_imp.plevel):
-                best_imp = imp
-        return best_imp
-

@register_func("relay.backend.create_schedule")
def create_schedule(src_func, target):