Skip to content

Commit cd0a516

Browse files
jcf94merrymercy
authored andcommitted
Code refine for tune_test.py & Add a pre load callback (apache#20)
* Bug fix for tutorials * Add PreLoadMeasuredStates * Add search_callback support for task tuner * Code refine for tune_test.py * Update * Update * Update * Update * Bug fix
1 parent 74ec7d0 commit cd0a516

File tree

16 files changed

+355
-193
lines changed

16 files changed

+355
-193
lines changed

python/tvm/ansor/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929

3030
# Shortcut
3131
from .compute_dag import ComputeDAG
32-
from .auto_schedule import SearchTask, MetaTileRewritePolicy, TuneOption, HardwareParams
32+
from .auto_schedule import SearchTask, MetaTileRewritePolicy, TuneOption, HardwareParams, PreLoadMeasuredStatesCallback
3333
from .auto_schedule import auto_schedule
3434
from .measure import MeasureInput, LocalBuilder, LocalRunner, RPCRunner, LocalRPCMeasureContext
3535
from .cost_model import RandomModel

python/tvm/ansor/auto_schedule.py

Lines changed: 33 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,15 @@ def __init__(self, dag, workload_key, target, target_host=None,
6969
class SearchPolicy(Object):
7070
def continue_search(self, task, num_measure, verbose, measurer):
7171
return _ffi_api.SearchPolicyContinueSearchOneRound(self, task, num_measure, verbose, measurer)
72+
73+
def set_task(self, task):
74+
_ffi_api.SearchPolicySetTask(self, task);
7275

76+
def set_verbose(self, verbose):
77+
_ffi_api.SearchPolicySetVerbose(self, verbose);
78+
79+
def run_callbacks(self, callbacks):
80+
_ffi_api.SearchPolicyRunCallbacks(self, callbacks)
7381

7482
@tvm._ffi.register_object("ansor.MetaTileRewritePolicy")
7583
class MetaTileRewritePolicy(SearchPolicy):
@@ -117,6 +125,21 @@ def __init__(self,
117125
seed or random.randint(1, 1 << 30))
118126

119127

128+
@tvm._ffi.register_object("ansor.SearchCallback")
129+
class SearchCallback(Object):
130+
pass
131+
132+
133+
@tvm._ffi.register_object("ansor.PreLoadMeasuredStatesCallback")
134+
class PreLoadMeasuredStatesCallback(SearchCallback):
135+
""" A SearchCallback that used for search policy to load measured hash
136+
from the log file.
137+
"""
138+
def __init__(self, filename: str):
139+
self.__init_handle_by_constructor__(
140+
_ffi_api.PreLoadMeasuredStatesCallback, filename)
141+
142+
120143
@tvm._ffi.register_object("ansor.TuneOption")
121144
class TuneOption(Object):
122145
""" The options for tuning
@@ -135,11 +158,13 @@ class TuneOption(Object):
135158
Builder which builds the program
136159
runner: Runner
137160
Runner which runs the program and measure time costs
138-
callbacks: List[MeasureCallback]
161+
measure_callbacks: List[MeasureCallback]
139162
Callback functions
163+
pre_search_callbacks: List[SearchCallback]
140164
"""
141165
def __init__(self, n_trials=0, early_stopping=-1, num_measure_per_iter=64,
142-
verbose=1, builder='local', runner='local', callbacks=None):
166+
verbose=1, builder='local', runner='local', measure_callbacks=None,
167+
pre_search_callbacks=None):
143168
if isinstance(builder, str):
144169
if builder == 'local':
145170
builder = LocalBuilder()
@@ -152,12 +177,15 @@ def __init__(self, n_trials=0, early_stopping=-1, num_measure_per_iter=64,
152177
else:
153178
raise ValueError("Invalid builder: " + runner)
154179

155-
if callbacks is None:
156-
callbacks = []
180+
if measure_callbacks is None:
181+
measure_callbacks = []
182+
183+
if pre_search_callbacks is None:
184+
pre_search_callbacks = []
157185

158186
self.__init_handle_by_constructor__(
159187
_ffi_api.TuneOption, n_trials, early_stopping, num_measure_per_iter,
160-
verbose, builder, runner, callbacks)
188+
verbose, builder, runner, measure_callbacks, pre_search_callbacks)
161189

162190

163191
def auto_schedule(workload, target=None,

python/tvm/ansor/measure.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,16 @@ def __init__(self,
174174

175175
@tvm._ffi.register_object("ansor.ProgramMeasurer")
176176
class ProgramMeasurer(Object):
177+
"""
178+
Parameters
179+
----------
180+
builder : Builder
181+
runner : Runner
182+
callbacks : List[MeasureCallback]
183+
verbose : Int
184+
max_continuous_error : Float
185+
"""
186+
177187
def __init__(self, builder: Builder, runner: Runner,
178188
callbacks: List[MeasureCallback],
179189
verbose: int, max_continuous_error: int = -1):
@@ -182,6 +192,21 @@ def __init__(self, builder: Builder, runner: Runner,
182192

183193
@tvm._ffi.register_object("ansor.RPCRunner")
184194
class RPCRunner(Runner):
195+
"""
196+
Parameters
197+
----------
198+
key : Str
199+
host : Str
200+
port : Int
201+
priority : Int
202+
n_parallel : Int
203+
timeout : Int
204+
number : Int
205+
repeat : Int
206+
min_repeat_ms : Int
207+
cooldown_interval : Float
208+
"""
209+
185210
def __init__(self, key, host, port, priority=1,
186211
n_parallel=1,
187212
timeout=10,
@@ -203,6 +228,19 @@ def __init__(self, key, host, port, priority=1,
203228

204229

205230
class LocalRPCMeasureContext:
231+
""" A context wrapper for RPCRunner.
232+
233+
Parameters
234+
----------
235+
priority : Int
236+
n_parallel : Int
237+
timeout : Int
238+
number : Int
239+
repeat : Int
240+
min_repeat_ms : Int
241+
cooldown_interval : Float
242+
"""
243+
206244
def __init__(self,
207245
priority=1,
208246
n_parallel=1,
@@ -228,8 +266,8 @@ def __init__(self,
228266
time.sleep(0.5)
229267

230268
def __del__(self):
231-
self.tracker.terminate()
232269
self.server.terminate()
270+
self.tracker.terminate()
233271

234272

235273
class MeasureErrorNo(object):

python/tvm/ansor/task_scheduler.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ def tune(self, tune_option: TuneOption, search_policy: Union[str, List[SearchPol
153153
self.tune_option = tune_option
154154
if self.use_debug_measurement_simulator is None:
155155
self.measurer = ProgramMeasurer(tune_option.builder, tune_option.runner,
156-
tune_option.callbacks, tune_option.verbose)
156+
tune_option.measure_callbacks, tune_option.verbose)
157157
self.ct = 0
158158
self.tic = time.time()
159159
# reset num_measure_per_iter to make sure every task is tuned at least once
@@ -167,6 +167,13 @@ def tune(self, tune_option: TuneOption, search_policy: Union[str, List[SearchPol
167167
self.sequential_now_task_idx = 0
168168
self.sequential_now_task_begin_ct = 0
169169

170+
for i in range(len(self.tasks)):
171+
search_policy = self.search_policies[i]
172+
task = self.tasks[i]
173+
search_policy.set_task(task)
174+
search_policy.set_verbose(tune_option.verbose)
175+
search_policy.run_callbacks(tune_option.pre_search_callbacks)
176+
170177
# do a round robin first
171178
if self.strategy != 'sequential':
172179
for i in range(len(self.tasks)):

0 commit comments

Comments
 (0)