diff --git a/lgbm/src/booster.rs b/lgbm/src/booster.rs index 74ccde5..10dd0a3 100644 --- a/lgbm/src/booster.rs +++ b/lgbm/src/booster.rs @@ -10,11 +10,11 @@ use lgbm_sys::{ LGBM_BoosterFree, LGBM_BoosterGetCurrentIteration, LGBM_BoosterGetEval, LGBM_BoosterGetEvalCounts, LGBM_BoosterGetEvalNames, LGBM_BoosterGetFeatureNames, LGBM_BoosterGetNumClasses, LGBM_BoosterGetNumFeature, LGBM_BoosterGetNumPredict, - LGBM_BoosterGetPredict, LGBM_BoosterLoadModelFromString, LGBM_BoosterPredictForMat, - LGBM_BoosterSaveModel, LGBM_BoosterSaveModelToString, LGBM_BoosterUpdateOneIter, - C_API_FEATURE_IMPORTANCE_GAIN, C_API_FEATURE_IMPORTANCE_SPLIT, C_API_MATRIX_TYPE_CSC, - C_API_MATRIX_TYPE_CSR, C_API_PREDICT_CONTRIB, C_API_PREDICT_LEAF_INDEX, C_API_PREDICT_NORMAL, - C_API_PREDICT_RAW_SCORE, + LGBM_BoosterGetPredict, LGBM_BoosterLoadModelFromString, LGBM_BoosterNumModelPerIteration, + LGBM_BoosterNumberOfTotalModel, LGBM_BoosterPredictForMat, LGBM_BoosterSaveModel, + LGBM_BoosterSaveModelToString, LGBM_BoosterUpdateOneIter, C_API_FEATURE_IMPORTANCE_GAIN, + C_API_FEATURE_IMPORTANCE_SPLIT, C_API_MATRIX_TYPE_CSC, C_API_MATRIX_TYPE_CSR, + C_API_PREDICT_CONTRIB, C_API_PREDICT_LEAF_INDEX, C_API_PREDICT_NORMAL, C_API_PREDICT_RAW_SCORE, }; use serde::{Deserialize, Serialize}; use std::{ @@ -320,6 +320,24 @@ impl Booster { } } + /// [LGBM_BoosterNumberOfTotalModel](https://lightgbm.readthedocs.io/en/latest/C-API.html#c.LGBM_BoosterNumberOfTotalModel) + pub fn number_of_total_model(&self) -> Result { + let mut value = 0; + unsafe { + to_result(LGBM_BoosterNumberOfTotalModel(self.handle, &mut value))?; + } + Ok(value as usize) + } + + /// [LGBM_BoosterNumModelPerIteration](https://lightgbm.readthedocs.io/en/latest/C-API.html#c.LGBM_BoosterNumModelPerIteration) + pub fn num_model_per_iteration(&self) -> Result { + let mut value = 0; + unsafe { + to_result(LGBM_BoosterNumModelPerIteration(self.handle, &mut value))?; + } + Ok(value as usize) + } + /// [LGBM_BoosterGetPredict](https://lightgbm.readthedocs.io/en/latest/C-API.html#c.LGBM_BoosterGetPredict) #[doc(alias = "LGBM_BoosterGetPredict")] pub fn get_predict(&self, data_idx: usize) -> Result {