- # Author: Stefan Wunsch CERN 09/2019
+ # Author: Stefan Wunsch CERN 09/2019

################################################################################
- # Copyright (C) 1995-2019, Rene Brun and Fons Rademakers.                      #
- # All rights reserved.                                                         #
- #                                                                              #
- # For the licensing terms see $ROOTSYS/LICENSE.                                #
- # For the list of contributors see $ROOTSYS/README/CREDITS.                    #
+ # Copyright (C) 1995-2019, Rene Brun and Fons Rademakers.                      #
+ # All rights reserved.                                                         #
+ #                                                                              #
+ # For the licensing terms see $ROOTSYS/LICENSE.                                #
+ # For the list of contributors see $ROOTSYS/README/CREDITS.                    #
################################################################################

from .. import pythonization
import cppyy

+ import json

- def SaveXGBoost(self, xgb_model, key_name, output_path, num_inputs, tmp_path="/tmp", threshold_dtype="float"):
+
+ def get_basescore(model):
+     """Get base score from an XGBoost sklearn estimator.
+
+     Copy-pasted from XGBoost unit test code.
+
+     See also:
+     * https://github.com/dmlc/xgboost/blob/a99bb38bd2762e35e6a1673a0c11e09eddd8e723/python-package/xgboost/testing/updater.py#L13
+     * https://github.com/dmlc/xgboost/issues/9347
+     * https://discuss.xgboost.ai/t/how-to-get-base-score-from-trained-booster/3192
+     """
+     base_score = float(json.loads(model.get_booster().save_config())["learner"]["learner_model_param"]["base_score"])
+     return base_score
+
+
+ def SaveXGBoost(xgb_model, key_name, output_path, num_inputs):
+     """
+     Saves the XGBoost model to a ROOT file as a TMVA::Experimental::RBDT object.
+
+     Args:
+         xgb_model: The trained XGBoost model.
+         key_name (str): The name to use for storing the RBDT in the output file.
+         output_path (str): The path to save the output file.
+         num_inputs (int): The number of input features used in the model.
+
+     Raises:
+         Exception: If the XGBoost model has an unsupported objective.
+     """
    # Extract objective
    objective_map = {
        "multi:softprob": "softmax",  # Naming the objective softmax is more common today
@@ -29,99 +57,25 @@ def SaveXGBoost(self, xgb_model, key_name, output_path, num_inputs, tmp_path="/t
    )
    objective = cppyy.gbl.std.string(objective_map[model_objective])

-     # Extract max depth of the trees
-     max_depth = xgb_model.max_depth
-
    # Determine number of outputs
-     if "reg:" in model_objective:
-         num_outputs = 1
-     elif "binary:" in model_objective:
-         num_outputs = 1
-     else:
-         num_outputs = xgb_model.n_classes_
+     num_outputs = xgb_model.n_classes_ if "multi:" in model_objective else 1
+
+     # Dump XGB model as json file
+     xgb_model.get_booster().dump_model(output_path, dump_format="json")

-     # Dump XGB model to the tmp folder as json file
-     import os
-     import uuid
+     with open(output_path, "r") as json_file:
+         forest = json.load(json_file)

-     tmp_path = os.path.join(tmp_path, str(uuid.uuid4()) + ".json")
-     xgb_model.get_booster().dump_model(tmp_path, dump_format="json")
+     # Dump XGB model as txt file
+     xgb_model.get_booster().dump_model(output_path)

-     import json
+     features = cppyy.gbl.std.vector["std::string"]([f"f{i}" for i in range(num_inputs)])
+     bdt = cppyy.gbl.TMVA.Experimental.RBDT.LoadText(output_path, features, num_outputs)

-     with open(tmp_path, "r") as json_file:
-         forest = json.load(json_file)
+     bdt.logistic_ = objective == "logistic"
+
+     bs = get_basescore(xgb_model)
+     bdt.baseScore_ = cppyy.gbl.std.log(bs / (1.0 - bs)) if bdt.logistic_ else bs

-     # Determine whether the model has a bias paramter and write bias trees
-     if hasattr(xgb_model, "base_score") and "reg:" in model_objective:
-         bias = xgb_model.base_score
-         if not bias == 0.0:
-             forest += [{"leaf": bias}] * num_outputs
-     # print(str(forest).replace("u'", "'").replace("'", '"'))
-
-     # Extract parameters from json and write to arrays
-     num_trees = len(forest)
-     len_inputs = 2 ** max_depth - 1
-     inputs = cppyy.gbl.std.vector["int"](len_inputs * num_trees, -1)
-     len_thresholds = 2 ** (max_depth + 1) - 1
-     thresholds = cppyy.gbl.std.vector[threshold_dtype](len_thresholds * num_trees)
-
-     def fill_arrays(node, index, inputs_base, thresholds_base):
-         # Set leaf score as threshold value if this node is a leaf
-         if "leaf" in node:
-             thresholds[thresholds_base + index] = node["leaf"]
-             return
-
-         # Set input index
-         input_ = int(node["split"].replace("f", ""))
-         inputs[inputs_base + index] = input_
-
-         # Set threshold value
-         thresholds[thresholds_base + index] = node["split_condition"]
-
-         # Find next left (no) and right (yes) node
-         if node["children"][0]["nodeid"] == node["yes"]:
-             yes, no = 1, 0
-         else:
-             yes, no = 0, 1
-
-         # Fill values from the child nodes
-         fill_arrays(node["children"][no], 2 * index + 1, inputs_base, thresholds_base)
-         fill_arrays(node["children"][yes], 2 * index + 2, inputs_base, thresholds_base)
-
-     for i_tree, tree in enumerate(forest):
-         fill_arrays(tree, 0, len_inputs * i_tree, len_thresholds * i_tree)
-
-     # Determine to which output node a tree belongs
-     outputs = cppyy.gbl.std.vector["int"](num_trees)
-     if num_outputs != 1:
-         for i in range(num_trees):
-             outputs[i] = int(i % num_outputs)
-
-     # Store arrays in a ROOT file in a folder with the given key name
-     # TODO: Write single values as simple integers and not vectors.
-     f = cppyy.gbl.TFile(output_path, "RECREATE")
-     f.mkdir(key_name)
-     d = f.Get(key_name)
-     d.WriteObjectAny(inputs, "std::vector<int>", "inputs")
-     d.WriteObjectAny(outputs, "std::vector<int>", "outputs")
-     d.WriteObjectAny(thresholds, "std::vector<" + threshold_dtype + ">", "thresholds")
-     d.WriteObjectAny(objective, "std::string", "objective")
-     max_depth_ = cppyy.gbl.std.vector["int"](1, max_depth)
-     d.WriteObjectAny(max_depth_, "std::vector<int>", "max_depth")
-     num_trees_ = cppyy.gbl.std.vector["int"](1, num_trees)
-     d.WriteObjectAny(num_trees_, "std::vector<int>", "num_trees")
-     num_inputs_ = cppyy.gbl.std.vector["int"](1, num_inputs)
-     d.WriteObjectAny(num_inputs_, "std::vector<int>", "num_inputs")
-     num_outputs_ = cppyy.gbl.std.vector["int"](1, num_outputs)
-     d.WriteObjectAny(num_outputs_, "std::vector<int>", "num_outputs")
-     f.Write()
-     f.Close()
-
-
- @pythonization("SaveXGBoost", ns="TMVA::Experimental")
- def pythonize_tree_inference(klass):
-     # Parameters:
-     # klass: class to be pythonized
-
-     klass.__init__ = SaveXGBoost
+     with cppyy.gbl.TFile.Open(output_path, "RECREATE") as tFile:
+         tFile.WriteObject(bdt, key_name)
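For reference, a minimal usage sketch of the rewritten SaveXGBoost follows. It is not part of the commit: the toy data, the key name "myBDT", and the file name "model.root" are placeholders, and the RBDT constructor and Compute call mirror the pattern of the TMVA tutorials, so the exact signatures may vary between ROOT versions.

    import numpy as np
    import xgboost
    import ROOT

    # Train a small binary classifier through the scikit-learn API on toy data.
    x = np.random.rand(1000, 4).astype(np.float32)
    y = (x[:, 0] > 0.5).astype(int)
    model = xgboost.XGBClassifier(n_estimators=10, max_depth=3, objective="binary:logistic")
    model.fit(x, y)

    # Store the model as a TMVA::Experimental::RBDT under the key "myBDT".
    ROOT.TMVA.Experimental.SaveXGBoost(model, "myBDT", "model.root", num_inputs=x.shape[1])

    # Read the model back by its key and run inference (constructor and Compute
    # follow the TMVA tutorials; assumed here, not shown in this diff).
    bdt = ROOT.TMVA.Experimental.RBDT("myBDT", "model.root")
    y_pred = bdt.Compute(x)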