Skip to content

Commit 0794875

Browse files
committed
Fix xgb error & Simplify dispatcher (apache#35)
1 parent 2c27816 commit 0794875

File tree

14 files changed

+70
-240
lines changed

14 files changed

+70
-240
lines changed

python/tvm/ansor/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
workload_key_to_dag, make_workload_key_func
4141
from .task_scheduler import TaskScheduler, SimpleTaskScheduler
4242
from .dispatcher import DispatchContext, ApplyConfig, ApplyHistoryBest as apply_history_best, \
43-
FallbackContext, clear_fallback_cache, ApplyGraphBest
43+
FallbackContext
4444
from .relay_integration import extract_from_program, extract_from_multiple_program, \
4545
finish_layout_rewrite, prepare_layout_rewrite, auto_schedule_topi
4646
from .env import GLOBAL_SCOPE

python/tvm/ansor/auto_schedule.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,6 @@ class MetaTileRewritePolicy(SearchPolicy):
9797
seed: int
9898
Random seed
9999
"""
100-
101100
def __init__(self,
102101
program_cost_model,
103102
params=None,

python/tvm/ansor/compute_dag.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,8 @@ def get_init_state(self):
5353

5454
def apply_steps_from_state(self, state, layout_rewrite_level=LayoutRewriteLevel.NO_REWRITE):
5555
"""
56+
Apply transform steps according to the history of a state
57+
5658
Parameters
5759
----------
5860
state : StateObject
@@ -68,6 +70,8 @@ def apply_steps_from_state(self, state, layout_rewrite_level=LayoutRewriteLevel.
6870

6971
def print_python_code_from_state(self, state):
7072
"""
73+
Print transform steps in the history of a state as TVM's python schedule primitive
74+
7175
Parameters
7276
----------
7377
state : StateObject
@@ -81,16 +85,29 @@ def print_python_code_from_state(self, state):
8185

8286
def infer_bound_from_state(self, state):
8387
"""
88+
Infer bound for a state
89+
8490
Parameters
8591
----------
8692
state : StateObject
8793
8894
Returns
8995
-------
90-
state : StateObject
96+
state : State
9197
"""
9298
state_obj = state if isinstance(state, StateObject) else state.state_object
9399
return State(_ffi_api.ComputeDAGInferBoundFromState(self, state_obj), self)
94100

95101
def rewrite_layout_from_state(self, state: State):
102+
"""
103+
Rewrite the layout according to the transform steps in the history of a state
104+
105+
Parameters
106+
----------
107+
state : StateObject
108+
109+
Returns
110+
-------
111+
state : StateObject
112+
"""
96113
return _ffi_api.ComputeDAGRewriteLayoutFromState(self, state)

python/tvm/ansor/cost_model/cost_model.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,18 +26,20 @@
2626

2727
@tvm._ffi.register_object("ansor.CostModel")
2828
class CostModel(Object):
29+
"""The base class for cost model"""
2930
pass
3031

3132

3233
@tvm._ffi.register_object("ansor.RandomModel")
3334
class RandomModel(Object):
35+
"""A model returns random estimation for all inputs"""
3436
def __init__(self):
3537
self.__init_handle_by_constructor__(_ffi_api.RandomModel)
3638

3739

38-
# A random number generator func for c++'s RandomModel
3940
@tvm._ffi.register_func("ansor.cost_model.random_number")
4041
def random_number(n, return_ptr):
42+
""" A random number generator func for c++'s RandomModel """
4143
if n == 0:
4244
return
4345
return_ptr = ctypes.cast(return_ptr, ctypes.POINTER(ctypes.c_float))
@@ -47,6 +49,7 @@ def random_number(n, return_ptr):
4749

4850
@tvm._ffi.register_object("ansor.PythonBasedModel")
4951
class PythonBasedModel(CostModel):
52+
"""Base class for cost models implemented in python"""
5053
def __init__(self):
5154
def update_func(inputs, results):
5255
self.update(inputs, results)

python/tvm/ansor/cost_model/xgb_model.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,16 +16,14 @@
1616
# under the License.
1717

1818
"""Cost model based on xgboost"""
19-
from typing import List
2019
import multiprocessing
2120
import logging
22-
import time
2321
from collections import defaultdict
2422

2523
import numpy as np
2624
import xgboost as xgb
2725

28-
from ...autotvm.tuner.xgboost_cost_model import get_rank, recall_curve, max_curve
26+
from tvm.autotvm.tuner.xgboost_cost_model import get_rank, recall_curve, max_curve
2927
from .cost_model import PythonBasedModel
3028
from ..feature import get_per_stmt_features_from_measure_pairs, get_per_stmt_features_from_states
3129
from ..serialization import LogReader
@@ -65,8 +63,8 @@ def __init__(self, verbose_eval=25, num_warmup_sample=100, seed=None):
6563
# todo(lmzheng): automatically decrease learning rate when the loss is too large
6664

6765
'n_gpus': 0,
68-
'n_threads': multiprocessing.cpu_count() / 2,
69-
'silent': 0,
66+
'nthread': multiprocessing.cpu_count() // 2,
67+
'verbosity': 0,
7068
'seed': seed or 43,
7169
'disable_default_eval_metric': 1
7270
}
@@ -180,7 +178,7 @@ def pack_sum_xgbmatrix_for_prediction(xs):
180178
x_flatten.append(row)
181179
pack_ids.append(ct)
182180

183-
return xgb.DMatrix(x_flatten), pack_ids
181+
return xgb.DMatrix(np.array(x_flatten)), pack_ids
184182

185183

186184
def pack_sum_xgbmatrix(xs, ys, gids=None, weights=None):
@@ -214,7 +212,7 @@ def pack_sum_xgbmatrix(xs, ys, gids=None, weights=None):
214212
y_flatten.append(y)
215213
pack_ids.append(ct)
216214

217-
ret = xgb.DMatrix(x_flatten, y_flatten)
215+
ret = xgb.DMatrix(np.array(x_flatten), y_flatten)
218216
if weights is not None:
219217
ret.set_weight(weights_flatten)
220218
dmatrix_context.put('pack_ids', ret, np.array(pack_ids))

0 commit comments

Comments
 (0)