fixup: address review comments by Vincenzo
guitargeek committed Apr 9, 2024
1 parent a893ef1 commit e610073
Showing 4 changed files with 61 additions and 70 deletions.
@@ -1,32 +1,46 @@
# Author: Stefan Wunsch CERN 09/2019
# Author : Stefan Wunsch CERN 09 / 2019

################################################################################
# Copyright (C) 1995-2019, Rene Brun and Fons Rademakers. #
# All rights reserved. #
# #
# For the licensing terms see $ROOTSYS/LICENSE. #
# For the list of contributors see $ROOTSYS/README/CREDITS. #
# Copyright(C) 1995 - 2019, Rene Brun and Fons Rademakers.#
# All rights reserved.#
# #
# For the licensing terms see $ROOTSYS / LICENSE.#
# For the list of contributors see $ROOTSYS / README / CREDITS.#
################################################################################

from .. import pythonization
import cppyy

import json

def get_basescore(model):
import json

def get_basescore(model):
"""Get base score from an XGBoost sklearn estimator.
Copy-pasted from XGBoost unit test code.
See also:
* https://github.com/dmlc/xgboost/blob/a99bb38bd2762e35e6a1673a0c11e09eddd8e723/python-package/xgboost/testing/updater.py#L13
* https://github.com/dmlc/xgboost/issues/9347
* https://discuss.xgboost.ai/t/how-to-get-base-score-from-trained-booster/3192
"""
base_score = float(json.loads(model.get_booster().save_config())["learner"]["learner_model_param"]["base_score"])
return base_score


def SaveXGBoost(xgb_model, key_name, output_path, num_inputs=None):
def SaveXGBoost(xgb_model, key_name, output_path, num_inputs):
"""
Saves the XGBoost model to a ROOT file as a TMVA::Experimental::RBDT object.
import json
Args:
xgb_model: The trained XGBoost model.
key_name (str): The name to use for storing the RBDT in the output file.
output_path (str): The path to save the output file.
num_inputs (int): The number of input features used in the model.
Raises:
Exception: If the XGBoost model has an unsupported objective.
"""
# Extract objective
objective_map = {
"multi:softprob": "softmax", # Naming the objective softmax is more common today
@@ -47,26 +61,18 @@ def SaveXGBoost(xgb_model, key_name, output_path, num_inputs=None):
max_depth = xgb_model.max_depth

# Determine number of outputs
num_outputs = 1
if "multi:" in model_objective:
num_outputs = xgb_model.n_classes_
num_outputs = xgb_model.n_classes_ if "multi:" in model_objective else 1

# Dump XGB model as json file
xgb_model.get_booster().dump_model(output_path, dump_format="json")

with open(output_path, "r") as json_file:
forest = json.load(json_file)

# Determine number of input variables
if num_inputs is None:
raise Exception(
"Failed to get number of input variables from XGBoost model. Please provide the additional keyword argument 'num_inputs' to this function."
)

xgb_model._Booster.dump_model(output_path)
xgb_model.get_booster().dump_model(output_path)

features = cppyy.gbl.std.vector["std::string"]([f"f{i}" for i in range(num_inputs)])
bdt = cppyy.gbl.TMVA.Experimental.RBDT.load_txt(output_path, features, num_outputs)
bdt = cppyy.gbl.TMVA.Experimental.RBDT.LoadText(output_path, features, num_outputs)

bdt.logistic_ = objective == "logistic"
if not bdt.logistic_:
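As a usage sketch (not part of this commit), the updated Python interface can be exercised from PyROOT as below; the dataset, key name, and file name are placeholders for illustration.

```python
# Hedged sketch: train a small XGBoost classifier and store it in a ROOT file
# as a TMVA::Experimental::RBDT object. "myBDT" and "model.root" are
# illustrative placeholders, not names taken from the commit.
import numpy as np
from xgboost import XGBClassifier
import ROOT

x = np.random.rand(1000, 4).astype(np.float32)
y = np.random.randint(0, 2, 1000)

model = XGBClassifier(n_estimators=20, max_depth=3, objective="binary:logistic")
model.fit(x, y)

# num_inputs is required after this commit (the num_inputs=None fallback was removed).
ROOT.TMVA.Experimental.SaveXGBoost(model, "myBDT", "model.root", num_inputs=x.shape[1])
```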
27 changes: 14 additions & 13 deletions tmva/tmva/inc/TMVA/RBDT.hxx
@@ -32,35 +32,36 @@ namespace TMVA {

namespace Experimental {

class RBDT {
class RBDT final {
public:
typedef float Value_t;

/// IO constructor (both for ROOT IO and LoadText()).
RBDT() = default;

/// Construct backends from model in ROOT file
/// Construct backends from model in ROOT file.
RBDT(const std::string &key, const std::string &filename);

/// Compute model prediction on a single event
/// Compute model prediction on a single event.
///
/// The method is intended to be used with std::vectors-like containers,
/// for example RVecs.
template <typename Vector>
Vector Compute(const Vector &x)
Vector Compute(const Vector &x) const
{
std::size_t nOut = baseResponses_.size() > 2 ? baseResponses_.size() : 1;
Vector y(nOut);
compute(x.data(), y.data());
ComputeImpl(x.data(), y.data());
return y;
}

/// Compute model prediction on a single event
inline std::vector<Value_t> Compute(std::vector<Value_t> const &x) { return Compute<std::vector<Value_t>>(x); }
/// Compute model prediction on a single event.
inline std::vector<Value_t> Compute(std::vector<Value_t> const &x) const { return Compute<std::vector<Value_t>>(x); }

RTensor<Value_t> Compute(RTensor<Value_t> const &x);
RTensor<Value_t> Compute(RTensor<Value_t> const &x) const;

static RBDT load_txt(std::string const &txtpath, std::vector<std::string> &features, int nClasses = 2);
static RBDT load_txt(std::istream &is, std::vector<std::string> &features, int nClasses = 2);
static RBDT LoadText(std::string const &txtpath, std::vector<std::string> &features, int nClasses = 2);
static RBDT LoadText(std::istream &is, std::vector<std::string> &features, int nClasses = 2);

std::vector<int> rootIndices_;
std::vector<unsigned int> cutIndices_;
@@ -74,9 +75,9 @@ public:
bool logistic_ = false;

private:
void softmax(const Value_t *array, Value_t *out) const;
void compute(const Value_t *array, Value_t *out) const;
Value_t evaluateBinary(const Value_t *array) const;
void Softmax(const Value_t *array, Value_t *out) const;
void ComputeImpl(const Value_t *array, Value_t *out) const;
Value_t EvaluateBinary(const Value_t *array) const;

ClassDefNV(RBDT, 1);
};
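For orientation, a hedged sketch of how the constructor and the now-const Compute overloads declared above can be driven from PyROOT; the key and file name are assumed to match whatever SaveXGBoost wrote earlier.

```python
# Hedged sketch of evaluating a stored RBDT from PyROOT; "myBDT" and
# "model.root" are assumptions carried over from the previous sketch.
import ROOT

bdt = ROOT.TMVA.Experimental.RBDT("myBDT", "model.root")

# Single-event prediction; Compute() is meant for std::vector-like containers
# (e.g. RVec), here a plain std::vector<float>.
event = ROOT.std.vector["float"]([0.1, 0.2, 0.3, 0.4])
prediction = bdt.Compute(event)
print(list(prediction))
```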
54 changes: 19 additions & 35 deletions tmva/tmva/src/RBDT.cxx
@@ -18,8 +18,11 @@

#include <TMVA/RBDT.hxx>

#include <TFile.h>
#include <ROOT/RSpan.hxx>
#include <ROOT/StringUtils.hxx>

#include <TFile.h>
#include <TSystem.h>

#include <cmath>
#include <fstream>
@@ -127,35 +130,14 @@ inline NumericAfterSubstrOutput<NumericType> numericAfterSubstr(std::string cons
return output;
}

std::vector<std::string> split(std::string const &strToSplit, char delimeter)
{
std::stringstream ss(strToSplit);
std::string item;
std::vector<std::string> splittedStrings;
while (std::getline(ss, item, delimeter)) {
splittedStrings.push_back(item);
}
return splittedStrings;
}

bool exists(std::string const &filename)
{
if (FILE *file = fopen(filename.c_str(), "r")) {
fclose(file);
return true;
} else {
return false;
}
}

} // namespace util

} // namespace

using namespace TMVA::Experimental;
using TMVA::Experimental::RTensor;

/// Compute model prediction on input RTensor
RTensor<TMVA::Experimental::RBDT::Value_t> TMVA::Experimental::RBDT::Compute(RTensor<Value_t> const &x)
RTensor<TMVA::Experimental::RBDT::Value_t> TMVA::Experimental::RBDT::Compute(RTensor<Value_t> const &x) const
{
std::size_t nOut = baseResponses_.size() > 2 ? baseResponses_.size() : 1;
const std::size_t rows = x.GetShape()[0];
@@ -167,15 +149,15 @@ RTensor<TMVA::Experimental::RBDT::Value_t> TMVA::Experimental::RBDT::Compute(RTe
for (std::size_t iCol = 0; iCol < cols; ++iCol) {
xRow[iCol] = x({iRow, iCol});
}
compute(xRow.data(), yRow.data());
ComputeImpl(xRow.data(), yRow.data());
for (std::size_t iOut = 0; iOut < nOut; ++iOut) {
y({iRow, iOut}) = yRow[iOut];
}
}
return y;
}

void TMVA::Experimental::RBDT::softmax(const Value_t *array, Value_t *out) const
void TMVA::Experimental::RBDT::Softmax(const Value_t *array, Value_t *out) const
{
std::size_t nOut = baseResponses_.size() > 2 ? baseResponses_.size() : 1;
if (nOut == 1) {
@@ -202,20 +184,20 @@ void TMVA::Experimental::RBDT::softmax(const Value_t *array, Value_t *out) const
detail::softmaxTransformInplace(out, nOut);
}

void TMVA::Experimental::RBDT::compute(const Value_t *array, Value_t *out) const
void TMVA::Experimental::RBDT::ComputeImpl(const Value_t *array, Value_t *out) const
{
std::size_t nOut = baseResponses_.size() > 2 ? baseResponses_.size() : 1;
if (nOut > 1) {
softmax(array, out);
Softmax(array, out);
} else {
out[0] = evaluateBinary(array);
out[0] = EvaluateBinary(array);
if (logistic_) {
out[0] = 1.0 / (1.0 + std::exp(-out[0]));
}
}
}

TMVA::Experimental::RBDT::Value_t TMVA::Experimental::RBDT::evaluateBinary(const Value_t *array) const
TMVA::Experimental::RBDT::Value_t TMVA::Experimental::RBDT::EvaluateBinary(const Value_t *array) const
{
Value_t out = baseScore_ + baseResponses_[0];

@@ -260,19 +242,21 @@ void terminateTree(TMVA::Experimental::RBDT &ff, int &nPreviousNodes, int &nPrev

} // namespace

RBDT TMVA::Experimental::RBDT::load_txt(std::string const &txtpath, std::vector<std::string> &features, int nClasses)
TMVA::Experimental::RBDT
TMVA::Experimental::RBDT::LoadText(std::string const &txtpath, std::vector<std::string> &features, int nClasses)
{
const std::string info = "constructing RBDT from " + txtpath + ": ";

if (!util::exists(txtpath)) {
if (gSystem->AccessPathName(txtpath.c_str())) {
throw std::runtime_error(info + "file does not exists");
}

std::ifstream file(txtpath.c_str());
return load_txt(file, features, nClasses);
return LoadText(file, features, nClasses);
}

RBDT TMVA::Experimental::RBDT::load_txt(std::istream &file, std::vector<std::string> &features, int nClasses)
TMVA::Experimental::RBDT
TMVA::Experimental::RBDT::LoadText(std::istream &file, std::vector<std::string> &features, int nClasses)
{
const std::string info = "constructing RBDT from istream: ";

@@ -314,7 +298,7 @@ RBDT TMVA::Experimental::RBDT::load_txt(std::istream &file, std::vector<std::str
ss >> index;
line = ss.str();

std::vector<std::string> splitstring = util::split(subline, '<');
std::vector<std::string> splitstring = ROOT::Split(subline, "<");
std::string const &varName = splitstring[0];
Value_t cutValue;
{
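The RTensor overload of Compute() made const above evaluates the model row by row; a hedged batch-evaluation sketch, assuming ROOT's AsRTensor adaptor is available for the numpy array, could look like this.

```python
# Hedged sketch of batch evaluation with the RTensor overload of Compute().
# AsRTensor and the key/file names are assumptions for illustration only.
import numpy as np
import ROOT

bdt = ROOT.TMVA.Experimental.RBDT("myBDT", "model.root")

data = np.random.rand(8, 4).astype(np.float32)
x = ROOT.TMVA.Experimental.AsRTensor(data)
y = bdt.Compute(x)  # one row of outputs per input row, as in the loop above
print(list(y.GetShape()))
```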
2 changes: 1 addition & 1 deletion tutorials/tmva/tmva103_Application.C
@@ -18,7 +18,7 @@ void tmva103_Application()
const char* model_filename = "tmva101.root";

if (gSystem->AccessPathName(model_filename)) {
Info("tmva102_Testing.py", "%s does not exist", model_filename);
Info("tmva103_Application.C", "%s does not exist", model_filename);
return;
}

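Finally, a hedged sketch of the file-existence check used in the tutorial and, after this commit, in RBDT::LoadText: TSystem::AccessPathName returns true when the path is not accessible, hence the seemingly inverted condition.

```python
# Hedged sketch of the check pattern: AccessPathName() returns true when the
# path is NOT accessible, so a true result means the file is missing.
import ROOT

filename = "tmva101.root"  # placeholder; the tutorial expects this file
if ROOT.gSystem.AccessPathName(filename):
    print(f"{filename} does not exist")
```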
