@@ -52,8 +52,7 @@ def _build(func,
5252def extract_from_program (func , params , ops , target , target_host = None ):
5353 """ Extract tuning tasks from a relay program.
5454
55- This function collects tuning tasks by building the program
56- with a "tracing" target and tracing all the calls to topi.
55+ This function is the single program version of extract_from_multiple_program.
5756
5857 Parameters
5958 ----------
@@ -73,66 +72,14 @@ def extract_from_program(func, params, ops, target, target_host=None):
7372 task: Array of autotvm.task.Task
7473 collected tasks
7574 """
76- import tvm .relay .op
77- from tvm import relay
78- import topi
79-
80- env = TaskExtractEnv .get ()
81-
82- # NOTE: To add more ops, you only need to change the following lists
83- # relay op -> topi compute
84- OP2TOPI = {
85- tvm .relay .op .nn .conv2d : [topi .nn .conv2d , topi .nn .depthwise_conv2d_nchw ,
86- topi .nn .group_conv2d_nchw , topi .nn .conv2d_NCHWc ],
87- tvm .relay .op .nn .conv2d_transpose : [topi .nn .conv2d_transpose_nchw ],
88- tvm .relay .op .nn .dense : [topi .nn .dense ],
89- tvm .relay .op .nn .deformable_conv2d : [topi .nn .deformable_conv2d_nchw ],
90- }
91-
92- topi_funcs = []
93- for op_name in ops :
94- if op_name in OP2TOPI :
95- topi_funcs .extend (OP2TOPI [op_name ])
96- else :
97- warnings .warn ("Op %s is not tunable, ignored" % op_name )
98-
99- # run compiler to collect all TOPI calls during compilation
100- env .reset (topi_funcs )
101- with env :
102- # disable logger temporarily
103- old_state = logger .disabled
104- logger .disabled = True
105-
106- relay .backend .compile_engine .get ().clear ()
107- # wrap build call in thread to avoid multiprocessing problems
108- mod = relay .Module .from_expr (func )
109- build_thread = threading .Thread (target = _build ,
110- args = (mod ,
111- target ,
112- target_host ,
113- params ))
114- build_thread .start ()
115- build_thread .join ()
116-
117- logger .disabled = old_state
118-
119- # create tasks for target
120- tasks = []
121- for task_name , args in env .get_tasks ():
122- try :
123- tsk = create (task_name , args ,
124- target = target , target_host = target_host ,
125- template_key = 'direct' )
126- tasks .append (tsk )
127- except topi .InvalidShapeError :
128- warnings .warn ("Invalid shape during AutoTVM task creation" )
129- return tasks
75+ return extract_from_multiple_program ([func ], [params ], ops , target , target_host )
13076
13177
13278def extract_from_multiple_program (funcs , params , ops , target , target_host = None ):
13379 """ Extract tuning tasks from multiple relay programs.
13480
135- This function is the multiple program version of extract_from_program
81+ This function collects tuning tasks by building a list of programs
82+ with a "tracing" target and tracing all the calls to topi.
13683
13784 Parameters
13885 ----------
@@ -152,19 +99,20 @@ def extract_from_multiple_program(funcs, params, ops, target, target_host=None):
15299 task: Array of autotvm.task.Task
153100 collected tasks
154101 """
155- env = TaskExtractEnv .get ()
156102 import tvm .relay .op
157103 from tvm import relay
158104 import topi
159105
106+ env = TaskExtractEnv .get ()
107+
160108 # NOTE: To add more ops, you only need to change the following lists
161109 # relay op -> topi compute
162110 OP2TOPI = {
163111 tvm .relay .op .nn .conv2d : [topi .nn .conv2d , topi .nn .depthwise_conv2d_nchw ,
164- topi .nn .group_conv2d_nchw ],
112+ topi .nn .group_conv2d_nchw , topi . nn . conv2d_NCHWc ],
165113 tvm .relay .op .nn .conv2d_transpose : [topi .nn .conv2d_transpose_nchw ],
166114 tvm .relay .op .nn .dense : [topi .nn .dense ],
167- tvm .relay .op .nn .contrib_deformable_conv2d : [topi .nn .deformable_conv2d_nchw ],
115+ tvm .relay .op .nn .deformable_conv2d : [topi .nn .deformable_conv2d_nchw ],
168116 }
169117
170118 topi_funcs = []
@@ -185,11 +133,8 @@ def extract_from_multiple_program(funcs, params, ops, target, target_host=None):
185133 relay .backend .compile_engine .get ().clear ()
186134 # wrap build call in thread to avoid multiprocessing problems
187135 mod = relay .Module .from_expr (func )
188- build_thread = threading .Thread (target = my_build ,
189- args = (mod ,
190- target ,
191- target_host ,
192- params ))
136+ build_thread = threading .Thread (target = _build ,
137+ args = (mod , target , target_host , param ))
193138 build_thread .start ()
194139 build_thread .join ()
195140
0 commit comments