Skip to content

Commit

Permalink
[tmva] Warn when using RBDT and xgboost
Browse files Browse the repository at this point in the history
Representing the current situation at #15197
  • Loading branch information
vepadulano committed Apr 11, 2024
1 parent 0d56835 commit d82c309
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,24 @@ def Compute(self, x):
# As fall-through we go to the original compute function and use the error-handling from cppyy
return self._OriginalCompute(x)

def RBDTInit(self, *args, **kwargs):
import warnings
warnings.warn(
("Usage of xgboost models through RBDT is known to be limited and may "
"lead to unexpected behaviour. Proceed with caution if the input model "
"was obtained via `SaveXGBoost`. See https://github.com/root-project/root/issues/15197 "
"for more details."), UserWarning, stacklevel=2)

return self._original_init(*args, **kwargs)


@pythonization("RBDT", ns="TMVA::Experimental", is_prefix=True)
def pythonize_rbdt(klass):
# Parameters:
# klass: class to be pythonized

klass._original_init = klass.__init__
klass.__init__ = RBDTInit

klass._OriginalCompute = klass.Compute
klass.Compute = Compute
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,12 @@
import cppyy


def SaveXGBoost(self, xgb_model, key_name, output_path, num_inputs=None, tmp_path="/tmp", threshold_dtype="float"):
def SaveXGBoost(self, xgb_model, key_name, output_path, num_inputs, tmp_path="/tmp", threshold_dtype="float"):
import warnings
warnings.warn(
("Usage of xgboost models through RBDT is known to be limited and may "
"lead to unexpected behaviour. See https://github.com/root-project/root/issues/15197 "
"for more details."), UserWarning, stacklevel=2)
# Extract objective
objective_map = {
"multi:softprob": "softmax", # Naming the objective softmax is more common today
Expand Down
5 changes: 3 additions & 2 deletions tmva/tmva/test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,9 @@ if(dataframe)
endif()

# Disabled because RBDT doesn't support the imbalanced tree structure of
# XGBoost models.
# if(dataframe AND NOT pyroot_legacy)
# XGBoost models. See https://github.com/root-project/root/issues/15197
# TODO: Re-enable once fixed
# if(dataframe)
# find_python_module(xgboost QUIET)
# if (PY_XGBOOST_FOUND)
# ROOT_ADD_PYUNITTEST(rbdt_xgboost rbdt_xgboost.py)
Expand Down
4 changes: 4 additions & 0 deletions tutorials/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,10 @@ set(gui_veto fit/fitpanel_playback.C
if (NOT ROOT_tmva_FOUND)
list(APPEND tmva_veto tmva/*.C tmva/*.py tmva/envelope/*.C tmva/keras/*.C tmva/keras/*.py tmva/pytorch/*.py )
else()
# RBDT with xgboost is currently broken, disable all tutorials that use it
# See https://github.com/root-project/root/issues/15197
# TODO: Re-enable once fixed
list(APPEND tmva_veto tmva/tmva10[0-2]_*.py)
#---These do not need to run for TMVA
list(APPEND tmva_veto tmva/createData.C)
if(MSVC AND NOT win_broken_tests)
Expand Down

0 comments on commit d82c309

Please sign in to comment.