Commit 6f1de3b

[TMVA] Merge RBDT with FastForest
Consolidate RBDT as specified in the ROOT plan of work 2024. The backends of RBDT are replaced with a single new backend based on the logic of the FastForest library: https://github.com/guitargeek/XGBoost-FastForest. The logic in that library was originally taken from the GBRForest in CMSSW: https://github.com/cms-sw/cmssw/blob/master/CommonTools/MVAUtils/interface/GBRForestTools.h. The interface remains the same, except that the template parameter specifying the backend is gone. This change also adds support for unbalanced trees.
1 parent 7a39d4e commit 6f1de3b
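For reference, a minimal PyROOT sketch of how the consolidated interface is meant to be used; the names model, "myModel", and "model.root" and the four input features are illustrative placeholders, not part of this commit:

    import ROOT

    # Store a trained XGBoost classifier as a TMVA::Experimental::RBDT in a ROOT file.
    ROOT.TMVA.Experimental.SaveXGBoost(model, "myModel", "model.root", num_inputs=4)

    # Load it back; RBDT is no longer a class template, so no backend parameter is needed.
    bdt = ROOT.TMVA.Experimental.RBDT("myModel", "model.root")

    # Single-event inference on a vector-like container.
    scores = bdt.Compute([0.1, 0.2, 0.3, 0.4])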

16 files changed: +528 −939 lines changed


bindings/pyroot/pythonizations/python/ROOT/_facade.py

Lines changed: 2 additions & 1 deletion
@@ -347,10 +347,11 @@ def TMVA(self):
         hasRDF = "dataframe" in gROOT.GetConfigFeatures()
         if hasRDF:
             try:
-                from ._pythonization._tmva import inject_rbatchgenerator, _AsRTensor
+                from ._pythonization._tmva import inject_rbatchgenerator, _AsRTensor, SaveXGBoost

                 inject_rbatchgenerator(ns)
                 ns.Experimental.AsRTensor = _AsRTensor
+                ns.Experimental.SaveXGBoost = SaveXGBoost
             except:
                 raise Exception("Failed to pythonize the namespace TMVA")
         del type(self).TMVA
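The practical effect of this facade change, sketched below under the assumption of a ROOT build with RDataFrame support: SaveXGBoost is attached lazily to the TMVA namespace the first time ROOT.TMVA is accessed, mirroring the hasRDF check above.

    import ROOT

    # The pythonization runs only in builds with RDataFrame support.
    if "dataframe" in ROOT.gROOT.GetConfigFeatures():
        save_fn = ROOT.TMVA.Experimental.SaveXGBoost  # resolves after the lazy pythonization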

bindings/pyroot/pythonizations/python/ROOT/_pythonization/_tmva/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -51,7 +51,7 @@ def inject_rbatchgenerator(ns):

 #this should be available only when xgboost is there ?
 # We probably don't need a protection here since the code is run only when there is xgboost
-from ._tree_inference import SaveXGBoost, pythonize_tree_inference
+from ._tree_inference import SaveXGBoost


 # list of python classes that are used to pythonize TMVA classes
bindings/pyroot/pythonizations/python/ROOT/_pythonization/_tmva/_tree_inference.py

Lines changed: 51 additions & 97 deletions
@@ -1,18 +1,46 @@
-# Author: Stefan Wunsch CERN 09/2019
+# Author : Stefan Wunsch CERN 09 / 2019

 ################################################################################
-# Copyright (C) 1995-2019, Rene Brun and Fons Rademakers. #
-# All rights reserved. #
-# #
-# For the licensing terms see $ROOTSYS/LICENSE. #
-# For the list of contributors see $ROOTSYS/README/CREDITS. #
+# Copyright(C) 1995 - 2019, Rene Brun and Fons Rademakers.#
+# All rights reserved.#
+# #
+# For the licensing terms see $ROOTSYS / LICENSE.#
+# For the list of contributors see $ROOTSYS / README / CREDITS.#
 ################################################################################

 from .. import pythonization
 import cppyy

+import json

-def SaveXGBoost(self, xgb_model, key_name, output_path, num_inputs, tmp_path="/tmp", threshold_dtype="float"):
+
+def get_basescore(model):
+    """Get base score from an XGBoost sklearn estimator.
+
+    Copy-pasted from XGBoost unit test code.
+
+    See also:
+    * https://github.com/dmlc/xgboost/blob/a99bb38bd2762e35e6a1673a0c11e09eddd8e723/python-package/xgboost/testing/updater.py#L13
+    * https://github.com/dmlc/xgboost/issues/9347
+    * https://discuss.xgboost.ai/t/how-to-get-base-score-from-trained-booster/3192
+    """
+    base_score = float(json.loads(model.get_booster().save_config())["learner"]["learner_model_param"]["base_score"])
+    return base_score
+
+
+def SaveXGBoost(xgb_model, key_name, output_path, num_inputs):
+    """
+    Saves the XGBoost model to a ROOT file as a TMVA::Experimental::RBDT object.
+
+    Args:
+        xgb_model: The trained XGBoost model.
+        key_name (str): The name to use for storing the RBDT in the output file.
+        output_path (str): The path to save the output file.
+        num_inputs (int): The number of input features used in the model.
+
+    Raises:
+        Exception: If the XGBoost model has an unsupported objective.
+    """
     # Extract objective
     objective_map = {
         "multi:softprob": "softmax",  # Naming the objective softmax is more common today
@@ -29,99 +57,25 @@ def SaveXGBoost(self, xgb_model, key_name, output_path, num_inputs, tmp_path="/tmp", threshold_dtype="float"):
     )
     objective = cppyy.gbl.std.string(objective_map[model_objective])

-    # Extract max depth of the trees
-    max_depth = xgb_model.max_depth
-
     # Determine number of outputs
-    if "reg:" in model_objective:
-        num_outputs = 1
-    elif "binary:" in model_objective:
-        num_outputs = 1
-    else:
-        num_outputs = xgb_model.n_classes_
+    num_outputs = xgb_model.n_classes_ if "multi:" in model_objective else 1
+
+    # Dump XGB model as json file
+    xgb_model.get_booster().dump_model(output_path, dump_format="json")

-    # Dump XGB model to the tmp folder as json file
-    import os
-    import uuid
+    with open(output_path, "r") as json_file:
+        forest = json.load(json_file)

-    tmp_path = os.path.join(tmp_path, str(uuid.uuid4()) + ".json")
-    xgb_model.get_booster().dump_model(tmp_path, dump_format="json")
+    # Dump XGB model as txt file
+    xgb_model.get_booster().dump_model(output_path)

-    import json
+    features = cppyy.gbl.std.vector["std::string"]([f"f{i}" for i in range(num_inputs)])
+    bdt = cppyy.gbl.TMVA.Experimental.RBDT.LoadText(output_path, features, num_outputs)

-    with open(tmp_path, "r") as json_file:
-        forest = json.load(json_file)
+    bdt.logistic_ = objective == "logistic"
+
+    bs = get_basescore(xgb_model)
+    bdt.baseScore_ = cppyy.gbl.std.log(bs / (1.0 - bs)) if bdt.logistic_ else bs

-    # Determine whether the model has a bias paramter and write bias trees
-    if hasattr(xgb_model, "base_score") and "reg:" in model_objective:
-        bias = xgb_model.base_score
-        if not bias == 0.0:
-            forest += [{"leaf": bias}] * num_outputs
-    # print(str(forest).replace("u'", "'").replace("'", '"'))
-
-    # Extract parameters from json and write to arrays
-    num_trees = len(forest)
-    len_inputs = 2 ** max_depth - 1
-    inputs = cppyy.gbl.std.vector["int"](len_inputs * num_trees, -1)
-    len_thresholds = 2 ** (max_depth + 1) - 1
-    thresholds = cppyy.gbl.std.vector[threshold_dtype](len_thresholds * num_trees)
-
-    def fill_arrays(node, index, inputs_base, thresholds_base):
-        # Set leaf score as threshold value if this node is a leaf
-        if "leaf" in node:
-            thresholds[thresholds_base + index] = node["leaf"]
-            return
-
-        # Set input index
-        input_ = int(node["split"].replace("f", ""))
-        inputs[inputs_base + index] = input_
-
-        # Set threshold value
-        thresholds[thresholds_base + index] = node["split_condition"]
-
-        # Find next left (no) and right (yes) node
-        if node["children"][0]["nodeid"] == node["yes"]:
-            yes, no = 1, 0
-        else:
-            yes, no = 0, 1
-
-        # Fill values from the child nodes
-        fill_arrays(node["children"][no], 2 * index + 1, inputs_base, thresholds_base)
-        fill_arrays(node["children"][yes], 2 * index + 2, inputs_base, thresholds_base)
-
-    for i_tree, tree in enumerate(forest):
-        fill_arrays(tree, 0, len_inputs * i_tree, len_thresholds * i_tree)
-
-    # Determine to which output node a tree belongs
-    outputs = cppyy.gbl.std.vector["int"](num_trees)
-    if num_outputs != 1:
-        for i in range(num_trees):
-            outputs[i] = int(i % num_outputs)
-
-    # Store arrays in a ROOT file in a folder with the given key name
-    # TODO: Write single values as simple integers and not vectors.
-    f = cppyy.gbl.TFile(output_path, "RECREATE")
-    f.mkdir(key_name)
-    d = f.Get(key_name)
-    d.WriteObjectAny(inputs, "std::vector<int>", "inputs")
-    d.WriteObjectAny(outputs, "std::vector<int>", "outputs")
-    d.WriteObjectAny(thresholds, "std::vector<" + threshold_dtype + ">", "thresholds")
-    d.WriteObjectAny(objective, "std::string", "objective")
-    max_depth_ = cppyy.gbl.std.vector["int"](1, max_depth)
-    d.WriteObjectAny(max_depth_, "std::vector<int>", "max_depth")
-    num_trees_ = cppyy.gbl.std.vector["int"](1, num_trees)
-    d.WriteObjectAny(num_trees_, "std::vector<int>", "num_trees")
-    num_inputs_ = cppyy.gbl.std.vector["int"](1, num_inputs)
-    d.WriteObjectAny(num_inputs_, "std::vector<int>", "num_inputs")
-    num_outputs_ = cppyy.gbl.std.vector["int"](1, num_outputs)
-    d.WriteObjectAny(num_outputs_, "std::vector<int>", "num_outputs")
-    f.Write()
-    f.Close()
-
-
-@pythonization("SaveXGBoost", ns="TMVA::Experimental")
-def pythonize_tree_inference(klass):
-    # Parameters:
-    # klass: class to be pythonized
-
-    klass.__init__ = SaveXGBoost
+    with cppyy.gbl.TFile.Open(output_path, "RECREATE") as tFile:
+        tFile.WriteObject(bdt, key_name)
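A standalone sketch of the base-score handling above, assuming a trained binary:logistic model in xgb_model (an illustrative name): the score is read from the booster configuration and converted to a raw log-odds margin, so XGBoost's default base_score of 0.5 maps to 0.0.

    import json
    import math

    # Read base_score from the booster config, as get_basescore() does above.
    config = json.loads(xgb_model.get_booster().save_config())
    base_score = float(config["learner"]["learner_model_param"]["base_score"])

    # For logistic objectives, the stored baseScore_ is the log-odds of base_score.
    raw_margin = math.log(base_score / (1.0 - base_score))  # 0.5 -> 0.0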

tmva/tmva/CMakeLists.txt

Lines changed: 0 additions & 4 deletions
@@ -462,10 +462,6 @@ ROOT_STANDARD_LIBRARY_PACKAGE(TMVAUtils
    TMVA/RBatchGenerator.hxx
    TMVA/RBatchLoader.hxx
    TMVA/RChunkLoader.hxx
-   TMVA/TreeInference/PythonHelpers.hxx
-   TMVA/TreeInference/BranchlessTree.hxx
-   TMVA/TreeInference/Forest.hxx
-   TMVA/TreeInference/Objectives.hxx

    SOURCES
tmva/tmva/inc/LinkDefUtils.h

Lines changed: 2 additions & 3 deletions
@@ -11,11 +11,10 @@

 #ifdef R__HAS_DATAFRAME
 // BDT inference
-#pragma link C++ class TMVA::Experimental::RBDT<TMVA::Experimental::BranchlessForest<float>>+;
-#pragma link C++ class TMVA::Experimental::RBDT<TMVA::Experimental::BranchlessJittedForest<float>>+;
+#pragma link C++ class TMVA::Experimental::RBDT+;
 #endif

 // RTensor will have its own streamer function
 #pragma link C++ class TMVA::Experimental::RTensor<float,std::vector<float>>-;

-#endif
+#endif

tmva/tmva/inc/TMVA/RBDT.hxx

Lines changed: 45 additions & 82 deletions
@@ -1,14 +1,15 @@
 /**********************************************************************************
  * Project: ROOT - a Root-integrated toolkit for multivariate data analysis *
  * Package: TMVA *
- * *
+ * *
  * *
  * Description: *
  * *
  * Authors: *
  * Stefan Wunsch (stefan.wunsch@cern.ch) *
+ * Jonas Rembser (jonas.rembser@cern.ch) *
  * *
- * Copyright (c) 2019: *
+ * Copyright (c) 2024: *
  * CERN, Switzerland *
  * *
  * Redistribution and use in source and binary forms, with or without *
@@ -19,108 +20,70 @@
 #ifndef TMVA_RBDT
 #define TMVA_RBDT

-#include "TMVA/RTensor.hxx"
-#include "TMVA/TreeInference/Forest.hxx"
-#include "TFile.h"
+#include <Rtypes.h>
+#include <TMVA/RTensor.hxx>

-#include <vector>
+#include <array>
+#include <istream>
 #include <string>
-#include <sstream> // std::stringstream
-#include <memory>
+#include <vector>

 namespace TMVA {
+
 namespace Experimental {

-/// Fast boosted decision tree inference
-template <typename Backend = BranchlessJittedForest<float>>
-class RBDT {
+class RBDT final {
 public:
-   using Value_t = typename Backend::Value_t;
-   using Backend_t = Backend;
+   typedef float Value_t;

-private:
-   int fNumOutputs;
-   bool fNormalizeOutputs;
-   std::vector<Backend_t> fBackends;
+   /// IO constructor (both for ROOT IO and LoadText()).
+   RBDT() = default;

-public:
-   /// Construct backends from model in ROOT file
-   RBDT(const std::string &key, const std::string &filename)
-   {
-      // Get number of output nodes of the forest
-      std::unique_ptr<TFile> file{TFile::Open(filename.c_str(),"READ")};
-      if (!file || file->IsZombie()) {
-         throw std::runtime_error("Failed to open input file " + filename);
-      }
-      auto numOutputs = Internal::GetObjectSafe<std::vector<int>>(file.get(), filename, key + "/num_outputs");
-      fNumOutputs = numOutputs->at(0);
-      delete numOutputs;
-
-      // Get objective and decide whether to normalize output nodes for example in the multiclass case
-      auto objective = Internal::GetObjectSafe<std::string>(file.get(), filename, key + "/objective");
-      if (objective->compare("softmax") == 0)
-         fNormalizeOutputs = true;
-      else
-         fNormalizeOutputs = false;
-      delete objective;
-      file->Close();
-
-      // Initialize backends
-      fBackends = std::vector<Backend_t>(fNumOutputs);
-      for (int i = 0; i < fNumOutputs; i++)
-         fBackends[i].Load(key, filename, i);
-   }
+   /// Construct backends from model in ROOT file.
+   RBDT(const std::string &key, const std::string &filename);

-   /// Compute model prediction on a single event
+   /// Compute model prediction on a single event.
    ///
    /// The method is intended to be used with std::vectors-like containers,
    /// for example RVecs.
    template <typename Vector>
-   Vector Compute(const Vector &x)
+   Vector Compute(const Vector &x) const
    {
-      Vector y;
-      y.resize(fNumOutputs);
-      for (int i = 0; i < fNumOutputs; i++)
-         fBackends[i].Inference(&x[0], 1, true, &y[i]);
-      if (fNormalizeOutputs) {
-         Value_t s = 0.0;
-         for (int i = 0; i < fNumOutputs; i++)
-            s += y[i];
-         for (int i = 0; i < fNumOutputs; i++)
-            y[i] /= s;
-      }
+      std::size_t nOut = baseResponses_.size() > 2 ? baseResponses_.size() : 1;
+      Vector y(nOut);
+      ComputeImpl(x.data(), y.data());
      return y;
    }

-   /// Compute model prediction on a single event
-   std::vector<Value_t> Compute(const std::vector<Value_t> &x) { return this->Compute<std::vector<Value_t>>(x); }
+   /// Compute model prediction on a single event.
+   inline std::vector<Value_t> Compute(std::vector<Value_t> const &x) const { return Compute<std::vector<Value_t>>(x); }

-   /// Compute model prediction on input RTensor
-   RTensor<Value_t> Compute(const RTensor<Value_t> &x)
-   {
-      const auto rows = x.GetShape()[0];
-      RTensor<Value_t> y({rows, static_cast<std::size_t>(fNumOutputs)}, MemoryLayout::ColumnMajor);
-      const bool layout = x.GetMemoryLayout() == MemoryLayout::ColumnMajor ? false : true;
-      for (int i = 0; i < fNumOutputs; i++)
-         fBackends[i].Inference(x.GetData(), rows, layout, &y(0, i));
-      if (fNormalizeOutputs) {
-         Value_t s;
-         for (int i = 0; i < static_cast<int>(rows); i++) {
-            s = 0.0;
-            for (int j = 0; j < fNumOutputs; j++)
-               s += y(i, j);
-            for (int j = 0; j < fNumOutputs; j++)
-               y(i, j) /= s;
-         }
-      }
-      return y;
-   }
-};
+   RTensor<Value_t> Compute(RTensor<Value_t> const &x) const;

-extern template class TMVA::Experimental::RBDT<TMVA::Experimental::BranchlessForest<float>>;
-extern template class TMVA::Experimental::RBDT<TMVA::Experimental::BranchlessJittedForest<float>>;
+   static RBDT LoadText(std::string const &txtpath, std::vector<std::string> &features, int nClasses = 2);
+   static RBDT LoadText(std::istream &is, std::vector<std::string> &features, int nClasses = 2);
+
+   std::vector<int> rootIndices_;
+   std::vector<unsigned int> cutIndices_;
+   std::vector<Value_t> cutValues_;
+   std::vector<int> leftIndices_;
+   std::vector<int> rightIndices_;
+   std::vector<Value_t> responses_;
+   std::vector<int> treeNumbers_;
+   std::vector<Value_t> baseResponses_;
+   Value_t baseScore_ = 0.0;
+   bool logistic_ = false;
+
+private:
+   void Softmax(const Value_t *array, Value_t *out) const;
+   void ComputeImpl(const Value_t *array, Value_t *out) const;
+   Value_t EvaluateBinary(const Value_t *array) const;
+
+   ClassDefNV(RBDT, 1);
+};

 } // namespace Experimental
+
 } // namespace TMVA

 #endif // TMVA_RBDT
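Batch inference goes through the RTensor overload of Compute declared above. A hedged PyROOT sketch; the file and key names, the 10x4 input shape, and the use of numpy are illustrative assumptions:

    import numpy as np
    import ROOT

    bdt = ROOT.TMVA.Experimental.RBDT("myModel", "model.root")

    # Wrap a float32 numpy array as an RTensor<float> and evaluate all rows at once.
    x = ROOT.TMVA.Experimental.AsRTensor(np.random.rand(10, 4).astype(np.float32))
    y = bdt.Compute(x)  # RTensor with one score (or one per class) per event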
