Skip to content

Commit

Permalink
[AutoScheduler] Fix the conflict of thread pool in measurement (apach…
Browse files Browse the repository at this point in the history
  • Loading branch information
merrymercy authored and trevor-m committed Jan 21, 2021
1 parent a4de1d1 commit 87673b8
Show file tree
Hide file tree
Showing 7 changed files with 96 additions and 158 deletions.
19 changes: 10 additions & 9 deletions python/tvm/auto_scheduler/measure.py
Original file line number Diff line number Diff line change
Expand Up @@ -602,9 +602,9 @@ def _timed_func(inp_serialized, build_func, verbose):

if verbose >= 1:
if error_no == MeasureErrorNo.NO_ERROR:
print(".", end="")
print(".", end="", flush=True)
else:
print(".E", end="") # Build error
print(".E", end="", flush=True) # Build error

return filename, args, error_no, error_msg, time.time() - tic

Expand Down Expand Up @@ -634,11 +634,11 @@ def local_build_worker(args):
res = call_func_with_timeout(timeout, _timed_func, args=(inp, build_func, verbose))
if isinstance(res, TimeoutError):
if verbose >= 1:
print(".T", end="") # Build timeout
print(".T", end="", flush=True) # Build timeout
res = None, [], MeasureErrorNo.BUILD_TIMEOUT, None, timeout
elif isinstance(res, Exception):
if verbose >= 1:
print(".E", end="") # Build error
print(".E", end="", flush=True) # Build error
res = None, [], MeasureErrorNo.COMPILE_HOST, str(res), timeout

return res
Expand Down Expand Up @@ -751,9 +751,9 @@ def _timed_eval_func(

if verbose >= 1:
if error_no == MeasureErrorNo.NO_ERROR:
print("*", end="")
print("*", end="", flush=True)
else:
print("*E", end="") # Run error
print("*E", end="", flush=True) # Run error
return costs, error_no, error_msg, toc - tic + build_res.time_cost, toc


Expand Down Expand Up @@ -839,10 +839,11 @@ def local_run(
enable_cpu_cache_flush,
verbose,
),
add_thread_wrapper=True,
)
if isinstance(res, TimeoutError):
if verbose >= 1:
print("*T", end="") # Run timeout
print("*T", end="", flush=True) # Run timeout
res = (
(MAX_FLOAT,),
MeasureErrorNo.RUN_TIMEOUT,
Expand All @@ -852,7 +853,7 @@ def local_run(
)
elif isinstance(res, Exception):
if verbose >= 1:
print("*E", end="") # Run error
print("*E", end="", flush=True) # Run error
res = (
(MAX_FLOAT,),
MeasureErrorNo.RUNTIME_DEVICE,
Expand All @@ -864,7 +865,7 @@ def local_run(
measure_results.append(MeasureResult(*res))

if verbose >= 1:
print("")
print("", flush=True)

return measure_results

Expand Down
51 changes: 43 additions & 8 deletions python/tvm/auto_scheduler/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name

""" Common utilities for auto_scheduler. """

Expand Down Expand Up @@ -162,22 +163,56 @@ def make_traceback_info():
return info


def _func_wrapper(que, func, args, kwargs):
class PropagatingThread(threading.Thread):
"""A thread that propagates the exception to the main thread"""

def run(self):
self.exc = None
try:
self.ret = self._target(*self._args, **self._kwargs)
except Exception as e: # pylint: disable=broad-except
self.exc = e

def join(self, timeout=None):
super(PropagatingThread, self).join(timeout)
if self.exc:
raise self.exc
return self.ret


def call_func_with_thread(func, args, kwargs):
"""Call a function within a new thread"""
res = []

def wrapper():
res.append(func(*args, **kwargs))

t = PropagatingThread(target=wrapper)
t.start()
t.join()
return res[0]


def _func_wrapper(que, func, args, kwargs, add_thread_wrapper):
"""Call function and return the result over the queue."""
try:
if kwargs:
que.put(func(*args, **kwargs))
if add_thread_wrapper:
# Add a new layer of threadinng to avoid the conflict between
# python's multiprocessing and tvm's thread pool.
res = call_func_with_thread(func, args, kwargs)
else:
que.put(func(*args))
# pylint: disable=broad-except
except Exception:
res = func(*args, **kwargs)
que.put(res)
except Exception: # pylint: disable=broad-except
que.put(Exception(make_traceback_info()))


def call_func_with_timeout(timeout, func, args=(), kwargs=None):
def call_func_with_timeout(timeout, func, args=(), kwargs=None, add_thread_wrapper=False):
"""Call a function with timeout"""
que = multiprocessing.Queue(2)
process = multiprocessing.Process(target=_func_wrapper, args=(que, func, args, kwargs))
process = multiprocessing.Process(
target=_func_wrapper, args=(que, func, args, kwargs or {}, add_thread_wrapper)
)
process.start()

try:
Expand Down
18 changes: 0 additions & 18 deletions python/tvm/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@ def test_something():
import os
import sys
import time
import threading
import pytest
import numpy as np
import tvm
Expand Down Expand Up @@ -743,21 +742,4 @@ def terminate_self():
sys.exit(-1)


class PropagatingThread(threading.Thread):
"""A thread that propagates the exection to the main thread"""

def run(self):
self.exc = None
try:
self.ret = self._target(*self._args, **self._kwargs)
except BaseException as e:
self.exc = e

def join(self, timeout=None):
super(PropagatingThread, self).join(timeout)
if self.exc:
raise self.exc
return self.ret


tvm._ffi._init_api("testing", __name__)
13 changes: 3 additions & 10 deletions tests/python/relay/test_auto_scheduler_layout_rewrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
from tvm import relay, auto_scheduler
from tvm.contrib import graph_runtime
import tvm.testing
from tvm.testing import PropagatingThread


def get_np_array(var, dtype):
Expand Down Expand Up @@ -139,23 +138,17 @@ def test_conv2d():
# wrap the search in a new thread to avoid the conflict
# between python's multiprocessing and tvm's thread pool
mod, data, weight = get_relay_conv2d(kh=1, kw=1)
t = PropagatingThread(target=tune_and_check, args=(mod, data, weight))
t.start()
t.join()
tune_and_check(mod, data, weight)


def test_dense():
mod, data, weight = get_relay_dense()
t = PropagatingThread(target=tune_and_check, args=(mod, data, weight))
t.start()
t.join()
tune_and_check(mod, data, weight)


def test_batch_matmul():
mod, data, weight = get_relay_batchmm()
t = PropagatingThread(target=tune_and_check, args=(mod, data, weight))
t.start()
t.join()
tune_and_check(mod, data, weight)


if __name__ == "__main__":
Expand Down
88 changes: 14 additions & 74 deletions tests/python/unittest/test_auto_scheduler_search_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@

import tvm
import tvm.testing
from tvm.testing import PropagatingThread
from tvm import auto_scheduler

from test_auto_scheduler_common import matmul_auto_scheduler_test
Expand Down Expand Up @@ -78,18 +77,12 @@ def search_common(
num_measures_per_round=2,
early_stopping=1,
runner=runner,
verbose=2,
measure_callbacks=[auto_scheduler.RecordToFile(log_file), CustomMeasureCallback()],
)
task.tune(tuning_options=tuning_options, search_policy=search_policy)
sch, args = task.apply_best(log_file)

print("==== Python Code ====")
print(task.print_best(log_file))

try:
print("==== Lowered Stmt ====")
print(tvm.lower(sch, args, simple_mode=True))
mod = tvm.build(sch, args, target)

ctx = tvm.context(str(target), 0)
Expand All @@ -99,52 +92,29 @@ def search_common(
c = tvm.nd.array(np.zeros((N, N), dtype=dtype), ctx)
mod(a, b, c)
tvm.testing.assert_allclose(c.asnumpy(), np.dot(a.asnumpy(), b.asnumpy()), rtol=1e-5)
print("==== Verification passed ====")
except Exception:
raise Exception("Error encountered with seed: %d" % (seed))
print()


@tvm.testing.requires_llvm
def test_workload_registry_search_basic():
# wrap the search in a new thread to avoid the conflict
# between python's multiprocessing and tvm's thread pool
t = PropagatingThread(
target=search_common, kwargs={"search_policy": "empty", "num_measure_trials": 2}
)
t.start()
t.join()

t = PropagatingThread(
target=search_common,
kwargs={
"workload": "matmul_auto_scheduler_test",
"num_measure_trials": 2,
"search_policy": "empty",
},
search_common(search_policy="empty", num_measure_trials=2)

search_common(
workload="matmul_auto_scheduler_test",
num_measure_trials=2,
search_policy="empty",
)
t.start()
t.join()

t = PropagatingThread(
target=search_common,
kwargs={
"workload": "matmul_auto_scheduler_test_rename_1",
"num_measure_trials": 2,
"search_policy": "empty",
},
search_common(
workload="matmul_auto_scheduler_test_rename_1",
num_measure_trials=2,
search_policy="empty",
)
t.start()
t.join()


@tvm.testing.requires_llvm
def test_sketch_search_policy_basic():
# wrap the search in a new thread to avoid the conflict
# between python's multiprocessing and tvm's thread pool
t = PropagatingThread(target=search_common)
t.start()
t.join()
search_common()


def sketch_search_policy_basic_spawn():
Expand All @@ -162,49 +132,19 @@ def test_sketch_search_policy_basic_spawn():

@tvm.testing.requires_llvm
def test_sketch_search_policy_xgbmodel():
# wrap the search in a new thread to avoid the conflict
# between python's multiprocessing and tvm's thread pool
t = PropagatingThread(
target=search_common,
kwargs={
"cost_model": auto_scheduler.XGBModel(),
},
)
t.start()
t.join()
search_common(cost_model=auto_scheduler.XGBModel())


@tvm.testing.requires_cuda
def test_sketch_search_policy_cuda_rpc_runner():
measure_ctx = auto_scheduler.LocalRPCMeasureContext()
# wrap the search in a new thread to avoid the conflict
# between python's multiprocessing and tvm's thread pool
t = PropagatingThread(
target=search_common,
kwargs={
"target": "cuda",
"runner": measure_ctx.runner,
},
)
t.start()
t.join()
search_common(target="cuda", runner=measure_ctx.runner)


@tvm.testing.requires_cuda
def test_sketch_search_policy_cuda_xgbmodel_rpc_runner():
measure_ctx = auto_scheduler.LocalRPCMeasureContext()
# wrap the search in a new thread to avoid the conflict
# between python's multiprocessing and tvm's thread pool
t = PropagatingThread(
target=search_common,
kwargs={
"target": "cuda",
"runner": measure_ctx.runner,
"cost_model": auto_scheduler.XGBModel(),
},
)
t.start()
t.join()
search_common(target="cuda", runner=measure_ctx.runner, cost_model=auto_scheduler.XGBModel())


if __name__ == "__main__":
Expand Down
34 changes: 20 additions & 14 deletions tutorials/auto_scheduler/tune_conv2d_layer_cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,18 +186,24 @@ def conv2d_layer(N, H, W, CO, CI, KH, KW, stride, padding):
# and resume the status of search policy and cost model with the log file.
# In the example below we resume the status and do more 5 trials.

cost_model = auto_scheduler.XGBModel()
cost_model.update_from_file(log_file)
search_policy = auto_scheduler.SketchPolicy(
task, cost_model, init_search_callbacks=[auto_scheduler.PreloadMeasuredStates(log_file)]
)
measure_ctx = auto_scheduler.LocalRPCMeasureContext(min_repeat_ms=300)
tune_option = auto_scheduler.TuningOptions(
num_measure_trials=5,
runner=measure_ctx.runner,
measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
)
task.tune(tune_option, search_policy=search_policy)

# Kill the measurement process
del measure_ctx
def resume_search(task, log_file):
print("Resume search:")
cost_model = auto_scheduler.XGBModel()
cost_model.update_from_file(log_file)
search_policy = auto_scheduler.SketchPolicy(
task, cost_model, init_search_callbacks=[auto_scheduler.PreloadMeasuredStates(log_file)]
)
measure_ctx = auto_scheduler.LocalRPCMeasureContext(min_repeat_ms=300)
tune_option = auto_scheduler.TuningOptions(
num_measure_trials=5,
runner=measure_ctx.runner,
measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
)
task.tune(tune_option, search_policy=search_policy)

# Kill the measurement process
del measure_ctx


resume_search(task, log_file)
Loading

0 comments on commit 87673b8

Please sign in to comment.