Skip to content

Commit e610073

Browse files
committed
fixup: address review comments by Vincenzo
1 parent a893ef1 commit e610073

File tree

4 files changed

+61
-70
lines changed

4 files changed

+61
-70
lines changed

bindings/pyroot/pythonizations/python/ROOT/_pythonization/_tmva/_tree_inference.py

Lines changed: 27 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,46 @@
1-
# Author: Stefan Wunsch CERN 09/2019
1+
# Author : Stefan Wunsch CERN 09 / 2019
22

33
################################################################################
4-
# Copyright (C) 1995-2019, Rene Brun and Fons Rademakers. #
5-
# All rights reserved. #
6-
# #
7-
# For the licensing terms see $ROOTSYS/LICENSE. #
8-
# For the list of contributors see $ROOTSYS/README/CREDITS. #
4+
# Copyright(C) 1995 - 2019, Rene Brun and Fons Rademakers.#
5+
# All rights reserved.#
6+
# #
7+
# For the licensing terms see $ROOTSYS / LICENSE.#
8+
# For the list of contributors see $ROOTSYS / README / CREDITS.#
99
################################################################################
1010

1111
from .. import pythonization
1212
import cppyy
1313

14+
import json
1415

15-
def get_basescore(model):
16-
import json
1716

17+
def get_basescore(model):
1818
"""Get base score from an XGBoost sklearn estimator.
1919
2020
Copy-pasted from XGBoost unit test code.
21+
22+
See also:
23+
* https://github.com/dmlc/xgboost/blob/a99bb38bd2762e35e6a1673a0c11e09eddd8e723/python-package/xgboost/testing/updater.py#L13
24+
* https://github.com/dmlc/xgboost/issues/9347
25+
* https://discuss.xgboost.ai/t/how-to-get-base-score-from-trained-booster/3192
2126
"""
2227
base_score = float(json.loads(model.get_booster().save_config())["learner"]["learner_model_param"]["base_score"])
2328
return base_score
2429

2530

26-
def SaveXGBoost(xgb_model, key_name, output_path, num_inputs=None):
31+
def SaveXGBoost(xgb_model, key_name, output_path, num_inputs):
32+
"""
33+
Saves the XGBoost model to a ROOT file as a TMVA::Experimental::RBDT object.
2734
28-
import json
35+
Args:
36+
xgb_model: The trained XGBoost model.
37+
key_name (str): The name to use for storing the RBDT in the output file.
38+
output_path (str): The path to save the output file.
39+
num_inputs (int): The number of input features used in the model.
2940
41+
Raises:
42+
Exception: If the XGBoost model has an unsupported objective.
43+
"""
3044
# Extract objective
3145
objective_map = {
3246
"multi:softprob": "softmax", # Naming the objective softmax is more common today
@@ -47,26 +61,18 @@ def SaveXGBoost(xgb_model, key_name, output_path, num_inputs=None):
4761
max_depth = xgb_model.max_depth
4862

4963
# Determine number of outputs
50-
num_outputs = 1
51-
if "multi:" in model_objective:
52-
num_outputs = xgb_model.n_classes_
64+
num_outputs = xgb_model.n_classes_ if "multi:" in model_objective else 1
5365

5466
# Dump XGB model as json file
5567
xgb_model.get_booster().dump_model(output_path, dump_format="json")
5668

5769
with open(output_path, "r") as json_file:
5870
forest = json.load(json_file)
5971

60-
# Determine number of input variables
61-
if num_inputs is None:
62-
raise Exception(
63-
"Failed to get number of input variables from XGBoost model. Please provide the additional keyword argument 'num_inputs' to this function."
64-
)
65-
66-
xgb_model._Booster.dump_model(output_path)
72+
xgb_model.get_booster().dump_model(output_path)
6773

6874
features = cppyy.gbl.std.vector["std::string"]([f"f{i}" for i in range(num_inputs)])
69-
bdt = cppyy.gbl.TMVA.Experimental.RBDT.load_txt(output_path, features, num_outputs)
75+
bdt = cppyy.gbl.TMVA.Experimental.RBDT.LoadText(output_path, features, num_outputs)
7076

7177
bdt.logistic_ = objective == "logistic"
7278
if not bdt.logistic_:

tmva/tmva/inc/TMVA/RBDT.hxx

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -32,35 +32,36 @@ namespace TMVA {
3232

3333
namespace Experimental {
3434

35-
class RBDT {
35+
class RBDT final {
3636
public:
3737
typedef float Value_t;
3838

39+
/// IO constructor (both for ROOT IO and LoadText()).
3940
RBDT() = default;
4041

41-
/// Construct backends from model in ROOT file
42+
/// Construct backends from model in ROOT file.
4243
RBDT(const std::string &key, const std::string &filename);
4344

44-
/// Compute model prediction on a single event
45+
/// Compute model prediction on a single event.
4546
///
4647
/// The method is intended to be used with std::vectors-like containers,
4748
/// for example RVecs.
4849
template <typename Vector>
49-
Vector Compute(const Vector &x)
50+
Vector Compute(const Vector &x) const
5051
{
5152
std::size_t nOut = baseResponses_.size() > 2 ? baseResponses_.size() : 1;
5253
Vector y(nOut);
53-
compute(x.data(), y.data());
54+
ComputeImpl(x.data(), y.data());
5455
return y;
5556
}
5657

57-
/// Compute model prediction on a single event
58-
inline std::vector<Value_t> Compute(std::vector<Value_t> const &x) { return Compute<std::vector<Value_t>>(x); }
58+
/// Compute model prediction on a single event.
59+
inline std::vector<Value_t> Compute(std::vector<Value_t> const &x) const { return Compute<std::vector<Value_t>>(x); }
5960

60-
RTensor<Value_t> Compute(RTensor<Value_t> const &x);
61+
RTensor<Value_t> Compute(RTensor<Value_t> const &x) const;
6162

62-
static RBDT load_txt(std::string const &txtpath, std::vector<std::string> &features, int nClasses = 2);
63-
static RBDT load_txt(std::istream &is, std::vector<std::string> &features, int nClasses = 2);
63+
static RBDT LoadText(std::string const &txtpath, std::vector<std::string> &features, int nClasses = 2);
64+
static RBDT LoadText(std::istream &is, std::vector<std::string> &features, int nClasses = 2);
6465

6566
std::vector<int> rootIndices_;
6667
std::vector<unsigned int> cutIndices_;
@@ -74,9 +75,9 @@ public:
7475
bool logistic_ = false;
7576

7677
private:
77-
void softmax(const Value_t *array, Value_t *out) const;
78-
void compute(const Value_t *array, Value_t *out) const;
79-
Value_t evaluateBinary(const Value_t *array) const;
78+
void Softmax(const Value_t *array, Value_t *out) const;
79+
void ComputeImpl(const Value_t *array, Value_t *out) const;
80+
Value_t EvaluateBinary(const Value_t *array) const;
8081

8182
ClassDefNV(RBDT, 1);
8283
};

tmva/tmva/src/RBDT.cxx

Lines changed: 19 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,11 @@
1818

1919
#include <TMVA/RBDT.hxx>
2020

21-
#include <TFile.h>
2221
#include <ROOT/RSpan.hxx>
22+
#include <ROOT/StringUtils.hxx>
23+
24+
#include <TFile.h>
25+
#include <TSystem.h>
2326

2427
#include <cmath>
2528
#include <fstream>
@@ -127,35 +130,14 @@ inline NumericAfterSubstrOutput<NumericType> numericAfterSubstr(std::string cons
127130
return output;
128131
}
129132

130-
std::vector<std::string> split(std::string const &strToSplit, char delimeter)
131-
{
132-
std::stringstream ss(strToSplit);
133-
std::string item;
134-
std::vector<std::string> splittedStrings;
135-
while (std::getline(ss, item, delimeter)) {
136-
splittedStrings.push_back(item);
137-
}
138-
return splittedStrings;
139-
}
140-
141-
bool exists(std::string const &filename)
142-
{
143-
if (FILE *file = fopen(filename.c_str(), "r")) {
144-
fclose(file);
145-
return true;
146-
} else {
147-
return false;
148-
}
149-
}
150-
151133
} // namespace util
152134

153135
} // namespace
154136

155-
using namespace TMVA::Experimental;
137+
using TMVA::Experimental::RTensor;
156138

157139
/// Compute model prediction on input RTensor
158-
RTensor<TMVA::Experimental::RBDT::Value_t> TMVA::Experimental::RBDT::Compute(RTensor<Value_t> const &x)
140+
RTensor<TMVA::Experimental::RBDT::Value_t> TMVA::Experimental::RBDT::Compute(RTensor<Value_t> const &x) const
159141
{
160142
std::size_t nOut = baseResponses_.size() > 2 ? baseResponses_.size() : 1;
161143
const std::size_t rows = x.GetShape()[0];
@@ -167,15 +149,15 @@ RTensor<TMVA::Experimental::RBDT::Value_t> TMVA::Experimental::RBDT::Compute(RTe
167149
for (std::size_t iCol = 0; iCol < cols; ++iCol) {
168150
xRow[iCol] = x({iRow, iCol});
169151
}
170-
compute(xRow.data(), yRow.data());
152+
ComputeImpl(xRow.data(), yRow.data());
171153
for (std::size_t iOut = 0; iOut < nOut; ++iOut) {
172154
y({iRow, iOut}) = yRow[iOut];
173155
}
174156
}
175157
return y;
176158
}
177159

178-
void TMVA::Experimental::RBDT::softmax(const Value_t *array, Value_t *out) const
160+
void TMVA::Experimental::RBDT::Softmax(const Value_t *array, Value_t *out) const
179161
{
180162
std::size_t nOut = baseResponses_.size() > 2 ? baseResponses_.size() : 1;
181163
if (nOut == 1) {
@@ -202,20 +184,20 @@ void TMVA::Experimental::RBDT::softmax(const Value_t *array, Value_t *out) const
202184
detail::softmaxTransformInplace(out, nOut);
203185
}
204186

205-
void TMVA::Experimental::RBDT::compute(const Value_t *array, Value_t *out) const
187+
void TMVA::Experimental::RBDT::ComputeImpl(const Value_t *array, Value_t *out) const
206188
{
207189
std::size_t nOut = baseResponses_.size() > 2 ? baseResponses_.size() : 1;
208190
if (nOut > 1) {
209-
softmax(array, out);
191+
Softmax(array, out);
210192
} else {
211-
out[0] = evaluateBinary(array);
193+
out[0] = EvaluateBinary(array);
212194
if (logistic_) {
213195
out[0] = 1.0 / (1.0 + std::exp(-out[0]));
214196
}
215197
}
216198
}
217199

218-
TMVA::Experimental::RBDT::Value_t TMVA::Experimental::RBDT::evaluateBinary(const Value_t *array) const
200+
TMVA::Experimental::RBDT::Value_t TMVA::Experimental::RBDT::EvaluateBinary(const Value_t *array) const
219201
{
220202
Value_t out = baseScore_ + baseResponses_[0];
221203

@@ -260,19 +242,21 @@ void terminateTree(TMVA::Experimental::RBDT &ff, int &nPreviousNodes, int &nPrev
260242

261243
} // namespace
262244

263-
RBDT TMVA::Experimental::RBDT::load_txt(std::string const &txtpath, std::vector<std::string> &features, int nClasses)
245+
TMVA::Experimental::RBDT
246+
TMVA::Experimental::RBDT::LoadText(std::string const &txtpath, std::vector<std::string> &features, int nClasses)
264247
{
265248
const std::string info = "constructing RBDT from " + txtpath + ": ";
266249

267-
if (!util::exists(txtpath)) {
250+
if (gSystem->AccessPathName(txtpath.c_str())) {
268251
throw std::runtime_error(info + "file does not exists");
269252
}
270253

271254
std::ifstream file(txtpath.c_str());
272-
return load_txt(file, features, nClasses);
255+
return LoadText(file, features, nClasses);
273256
}
274257

275-
RBDT TMVA::Experimental::RBDT::load_txt(std::istream &file, std::vector<std::string> &features, int nClasses)
258+
TMVA::Experimental::RBDT
259+
TMVA::Experimental::RBDT::LoadText(std::istream &file, std::vector<std::string> &features, int nClasses)
276260
{
277261
const std::string info = "constructing RBDT from istream: ";
278262

@@ -314,7 +298,7 @@ RBDT TMVA::Experimental::RBDT::load_txt(std::istream &file, std::vector<std::str
314298
ss >> index;
315299
line = ss.str();
316300

317-
std::vector<std::string> splitstring = util::split(subline, '<');
301+
std::vector<std::string> splitstring = ROOT::Split(subline, "<");
318302
std::string const &varName = splitstring[0];
319303
Value_t cutValue;
320304
{

tutorials/tmva/tmva103_Application.C

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ void tmva103_Application()
1818
const char* model_filename = "tmva101.root";
1919

2020
if (gSystem->AccessPathName(model_filename)) {
21-
Info("tmva102_Testing.py", "%s does not exist", model_filename);
21+
Info("tmva103_Application.C", "%s does not exist", model_filename);
2222
return;
2323
}
2424

0 commit comments

Comments
 (0)