Skip to content

Commit

Permalink
Add number_of_total_model and num_model_per_iteration for Booster.
Browse files Browse the repository at this point in the history
  • Loading branch information
frozenlib committed Dec 28, 2023
1 parent 55d3b66 commit 8e57f19
Showing 1 changed file with 23 additions and 5 deletions.
28 changes: 23 additions & 5 deletions lgbm/src/booster.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@ use lgbm_sys::{
LGBM_BoosterFree, LGBM_BoosterGetCurrentIteration, LGBM_BoosterGetEval,
LGBM_BoosterGetEvalCounts, LGBM_BoosterGetEvalNames, LGBM_BoosterGetFeatureNames,
LGBM_BoosterGetNumClasses, LGBM_BoosterGetNumFeature, LGBM_BoosterGetNumPredict,
LGBM_BoosterGetPredict, LGBM_BoosterLoadModelFromString, LGBM_BoosterPredictForMat,
LGBM_BoosterSaveModel, LGBM_BoosterSaveModelToString, LGBM_BoosterUpdateOneIter,
C_API_FEATURE_IMPORTANCE_GAIN, C_API_FEATURE_IMPORTANCE_SPLIT, C_API_MATRIX_TYPE_CSC,
C_API_MATRIX_TYPE_CSR, C_API_PREDICT_CONTRIB, C_API_PREDICT_LEAF_INDEX, C_API_PREDICT_NORMAL,
C_API_PREDICT_RAW_SCORE,
LGBM_BoosterGetPredict, LGBM_BoosterLoadModelFromString, LGBM_BoosterNumModelPerIteration,
LGBM_BoosterNumberOfTotalModel, LGBM_BoosterPredictForMat, LGBM_BoosterSaveModel,
LGBM_BoosterSaveModelToString, LGBM_BoosterUpdateOneIter, C_API_FEATURE_IMPORTANCE_GAIN,
C_API_FEATURE_IMPORTANCE_SPLIT, C_API_MATRIX_TYPE_CSC, C_API_MATRIX_TYPE_CSR,
C_API_PREDICT_CONTRIB, C_API_PREDICT_LEAF_INDEX, C_API_PREDICT_NORMAL, C_API_PREDICT_RAW_SCORE,
};
use serde::{Deserialize, Serialize};
use std::{
Expand Down Expand Up @@ -320,6 +320,24 @@ impl Booster {
}
}

/// [LGBM_BoosterNumberOfTotalModel](https://lightgbm.readthedocs.io/en/latest/C-API.html#c.LGBM_BoosterNumberOfTotalModel)
pub fn number_of_total_model(&self) -> Result<usize> {
let mut value = 0;
unsafe {
to_result(LGBM_BoosterNumberOfTotalModel(self.handle, &mut value))?;
}
Ok(value as usize)
}

/// [LGBM_BoosterNumModelPerIteration](https://lightgbm.readthedocs.io/en/latest/C-API.html#c.LGBM_BoosterNumModelPerIteration)
pub fn num_model_per_iteration(&self) -> Result<usize> {
let mut value = 0;
unsafe {
to_result(LGBM_BoosterNumModelPerIteration(self.handle, &mut value))?;
}
Ok(value as usize)
}

/// [LGBM_BoosterGetPredict](https://lightgbm.readthedocs.io/en/latest/C-API.html#c.LGBM_BoosterGetPredict)
#[doc(alias = "LGBM_BoosterGetPredict")]
pub fn get_predict(&self, data_idx: usize) -> Result<Prediction> {
Expand Down

0 comments on commit 8e57f19

Please sign in to comment.