Skip to content

Commit

Permalink
Add capability to get possible max and min values for a model
Browse files Browse the repository at this point in the history
  • Loading branch information
JoanFM authored and joanfontanals committed Feb 3, 2020
1 parent 8653098 commit 971720f
Show file tree
Hide file tree
Showing 7 changed files with 106 additions and 0 deletions.
12 changes: 12 additions & 0 deletions include/LightGBM/boosting.h
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,18 @@ class LIGHTGBM_EXPORT Boosting {
*/
virtual std::vector<double> FeatureImportance(int num_iteration, int importance_type) const = 0;

/*!
* \brief Calculate max possible value
* \return max possible value
*/
virtual double GetMaxValue() const = 0;

/*!
* \brief Calculate min possible value
* \return min possible value
*/
virtual double GetMinValue() const = 0;

/*!
* \brief Get max feature index of this model
* \return Max feature index of this model
Expand Down
18 changes: 18 additions & 0 deletions include/LightGBM/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -988,6 +988,24 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterFeatureImportance(BoosterHandle handle,
int importance_type,
double* out_results);

/*!
* \brief Get model max possible value.
* \param handle Handle of booster
* \param[out] out_results Result pointer to max value
* \return 0 when succeed, -1 when failure happens
*/
LIGHTGBM_C_EXPORT int LGBM_BoosterGetMaxValue(BoosterHandle handle,
double* out_results);

/*!
* \brief Get model min possible value.
* \param handle Handle of booster
* \param[out] out_results Result pointing to min value
* \return 0 when succeed, -1 when failure happens
*/
LIGHTGBM_C_EXPORT int LGBM_BoosterGetMinValue(BoosterHandle handle,
double* out_results);

/*!
* \brief Initialize the network.
* \param machines List of machines in format 'ip1:port1,ip2:port2'
Expand Down
12 changes: 12 additions & 0 deletions include/LightGBM/tree.h
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,18 @@ class Tree {
const data_size_t* used_data_indices,
data_size_t num_data, double* score) const;

/*!
* \brief Adding max leaf value of this tree model to score
* \param score Will add value to score
*/
void AddMaxValueToScore(double& score) const;

/*!
* \brief Adding min leaf value of this tree model to score
* \param score Will add valu to score
*/
void AddMinValueToScore(double& score) const;

/*!
* \brief Prediction on one record
* \param feature_values Feature value of this record
Expand Down
12 changes: 12 additions & 0 deletions src/boosting/gbdt.h
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,18 @@ class GBDT : public GBDTBase {
*/
std::vector<double> FeatureImportance(int num_iteration, int importance_type) const override;

/*!
* \brief Calculate max possible value
* \return max possible value
*/
double GetMaxValue() const override;

/*!
* \brief Calculate min possible value
* \return min possible value
*/
double GetMinValue() const override;

/*!
* \brief Get max feature index of this model
* \return Max feature index of this model
Expand Down
18 changes: 18 additions & 0 deletions src/boosting/gbdt_prediction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,4 +88,22 @@ void GBDT::PredictLeafIndexByMap(const std::unordered_map<int, double>& features
}
}

double GBDT::GetMaxValue() const {
double max_value = 0.0;
int total_tree = num_iteration_for_pred_ * num_tree_per_iteration_;
for (int i = 0; i < total_tree; ++i) {
models_[i]->AddMaxValueToScore(max_value);
}
return max_value;
}

double GBDT::GetMinValue() const {
double min_value = 0.0;
int total_tree = num_iteration_for_pred_ * num_tree_per_iteration_;
for (int i = 0; i < total_tree; ++i) {
models_[i]->AddMinValueToScore(min_value);
}
return min_value;
}

} // namespace LightGBM
26 changes: 26 additions & 0 deletions src/c_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,14 @@ class Booster {
return boosting_->FeatureImportance(num_iteration, importance_type);
}

double MaxValue() {
return boosting_->GetMaxValue();
}

double MinValue() {
return boosting_->GetMinValue();
}

double GetLeafValue(int tree_idx, int leaf_idx) const {
return dynamic_cast<GBDTBase*>(boosting_.get())->GetLeafValue(tree_idx, leaf_idx);
}
Expand Down Expand Up @@ -1583,6 +1591,24 @@ int LGBM_BoosterFeatureImportance(BoosterHandle handle,
API_END();
}

int LGBM_BoosterGetMaxValue(BoosterHandle handle,
double* out_results) {
API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle);
double max_value = ref_booster->MaxValue();
out_results[0] = max_value;
API_END();
}

int LGBM_BoosterGetMinValue(BoosterHandle handle,
double* out_results) {
API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle);
double min_value = ref_booster->MinValue();
out_results[0] = min_value;
API_END();
}

int LGBM_NetworkInit(const char* machines,
int local_listen_port,
int listen_time_out,
Expand Down
8 changes: 8 additions & 0 deletions src/io/tree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,14 @@ void Tree::AddPredictionToScore(const Dataset* data,
}
}

void Tree::AddMaxValueToScore(double& score) const {
score += *std::max_element(leaf_value_.begin(), leaf_value_.end());
}

void Tree::AddMinValueToScore(double& score) const {
score += *std::min_element(leaf_value_.begin(), leaf_value_.end());
}

#undef PredictionFun

std::string Tree::ToString() const {
Expand Down

0 comments on commit 971720f

Please sign in to comment.