Skip to content

Commit d1f7ef4

Browse files
authored
[XGBoost,MetaSchedule] Support xgb set tree method (apache#15133)
1 parent 64ac43a commit d1f7ef4

File tree

3 files changed

+21
-4
lines changed

3 files changed

+21
-4
lines changed

python/tvm/meta_schedule/cost_model/cost_model.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -127,9 +127,12 @@ def create(
127127
if kind == "xgb":
128128
return XGBModel(*args, **kwargs) # type: ignore
129129

130-
if "num_tuning_cores" in kwargs:
131-
# num_tuning_cores is only relevant for XGBModel.
132-
kwargs.pop("num_tuning_cores")
130+
# params only relevant to XGBModel
131+
_xgb_params = ["num_tuning_cores", "tree_method"]
132+
133+
for param in _xgb_params:
134+
if param in kwargs:
135+
kwargs.pop(param)
133136

134137
if kind == "random":
135138
return RandomModel(*args, **kwargs) # type: ignore

python/tvm/meta_schedule/cost_model/xgb_model.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
from itertools import chain as itertools_chain
2222
from typing import TYPE_CHECKING, Any, Callable, Dict, List, NamedTuple, Optional, Tuple
2323

24+
from typing_extensions import Literal
25+
2426
import numpy as np # type: ignore
2527

2628
from ...contrib.tar import tar, untar
@@ -202,6 +204,8 @@ def average_peak_score(
202204
class XGBConfig(NamedTuple):
203205
"""XGBoost model configuration
204206
207+
Reference: https://xgboost.readthedocs.io/en/stable/parameter.html
208+
205209
Parameters
206210
----------
207211
max_depth : int
@@ -217,6 +221,8 @@ class XGBConfig(NamedTuple):
217221
nthread : Optional[int],
218222
The number of threads to use.
219223
Default is None, which means to use physical number of cores.
224+
tree_method : Literal["auto", "exact", "approx", "hist", "gpu_hist"]
225+
The tree construction algorithm used in XGBoost.
220226
"""
221227

222228
max_depth: int = 10
@@ -225,15 +231,19 @@ class XGBConfig(NamedTuple):
225231
eta: float = 0.2
226232
seed: int = 43
227233
nthread: Optional[int] = None
234+
tree_method: Literal["auto", "exact", "approx", "hist", "gpu_hist"] = "auto"
228235

229236
def to_dict(self):
237+
"""Convert to dict"""
238+
230239
return {
231240
"max_depth": self.max_depth,
232241
"gamma": self.gamma,
233242
"min_child_weight": self.min_child_weight,
234243
"eta": self.eta,
235244
"seed": self.seed,
236245
"nthread": self.nthread,
246+
"tree_method": self.tree_method,
237247
}
238248

239249

@@ -334,6 +344,7 @@ def __init__(
334344
average_peak_n: int = 32,
335345
adaptive_training: bool = True,
336346
num_tuning_cores: Optional[int] = None,
347+
tree_method: Optional[Literal["auto", "exact", "approx", "hist", "gpu_hist"]] = None,
337348
):
338349
super().__init__()
339350
if not isinstance(extractor, FeatureExtractor):
@@ -348,6 +359,9 @@ def __init__(
348359
else:
349360
config = config._replace(nthread=num_tuning_cores)
350361

362+
if tree_method is not None:
363+
config = config._replace(tree_method=tree_method)  # NOTE(review): _replace returns a new NamedTuple; without re-assignment the tree_method argument was silently ignored
364+
351365
self.config = config
352366
# behavior of randomness
353367
self.num_warmup_samples = num_warmup_samples

python/tvm/meta_schedule/tune.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ def tune_tasks(
108108
elif not isinstance(database, Database):
109109
database = Database.create(database, module_equality=module_equality)
110110
if not isinstance(cost_model, CostModel):
111-
cost_model = CostModel.create(cost_model, num_tuning_cores=num_cores)
111+
cost_model = CostModel.create(cost_model, num_tuning_cores=num_cores, tree_method="auto")
112112
if isinstance(measure_callbacks, MeasureCallback):
113113
measure_callbacks = [measure_callbacks]
114114
elif measure_callbacks == "default":

0 commit comments

Comments
 (0)