2121from itertools import chain as itertools_chain
2222from typing import TYPE_CHECKING , Any , Callable , Dict , List , NamedTuple , Optional , Tuple
2323
24+ from typing_extensions import Literal
25+
2426import numpy as np # type: ignore
2527
2628from ...contrib .tar import tar , untar
@@ -202,6 +204,8 @@ def average_peak_score(
202204class XGBConfig (NamedTuple ):
203205 """XGBoost model configuration
204206
207+ Reference: https://xgboost.readthedocs.io/en/stable/parameter.html
208+
205209 Parameters
206210 ----------
207211 max_depth : int
@@ -217,6 +221,8 @@ class XGBConfig(NamedTuple):
217221 nthread : Optional[int],
218222 The number of threads to use.
219223 Default is None, which means to use physical number of cores.
224+ tree_method : Literal["auto", "exact", "approx", "hist", "gpu_hist"]
225+ The tree construction algorithm used in XGBoost.
220226 """
221227
222228 max_depth : int = 10
@@ -225,15 +231,19 @@ class XGBConfig(NamedTuple):
225231 eta : float = 0.2
226232 seed : int = 43
227233 nthread : Optional [int ] = None
234+ tree_method : Literal ["auto" , "exact" , "approx" , "hist" , "gpu_hist" ] = "auto"
228235
229236 def to_dict (self ):
237+ """Convert to dict"""
238+
230239 return {
231240 "max_depth" : self .max_depth ,
232241 "gamma" : self .gamma ,
233242 "min_child_weight" : self .min_child_weight ,
234243 "eta" : self .eta ,
235244 "seed" : self .seed ,
236245 "nthread" : self .nthread ,
246+ "tree_method" : self .tree_method ,
237247 }
238248
239249
@@ -334,6 +344,7 @@ def __init__(
334344 average_peak_n : int = 32 ,
335345 adaptive_training : bool = True ,
336346 num_tuning_cores : Optional [int ] = None ,
347+ tree_method : Optional [Literal ["auto" , "exact" , "approx" , "hist" , "gpu_hist" ]] = None ,
337348 ):
338349 super ().__init__ ()
339350 if not isinstance (extractor , FeatureExtractor ):
@@ -348,6 +359,9 @@ def __init__(
348359 else :
349360 config = config ._replace (nthread = num_tuning_cores )
350361
362+ if tree_method is not None :
363+ config ._replace (tree_method = tree_method )
364+
351365 self .config = config
352366 # behavior of randomness
353367 self .num_warmup_samples = num_warmup_samples
0 commit comments