Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add prediction early stopping #550

Merged
merged 8 commits into from
May 29, 2017
Prev Previous commit
Next Next commit
Fix GBDT if-else prediction with early stopping
  • Loading branch information
Carlos Becker committed May 29, 2017
commit 779074f8a20f3c07c67726c554ddb5f3d144b430
34 changes: 17 additions & 17 deletions src/boosting/gbdt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -704,10 +704,10 @@ std::string GBDT::ModelToIfElse(int num_iteration) const {
std::stringstream str_buf;

str_buf << "#include \"gbdt.h\"" << std::endl;
str_buf << "#include <LightGBM/utils/openmp_wrapper.h>" << std::endl;
str_buf << "#include <LightGBM/utils/common.h>" << std::endl;
str_buf << "#include <LightGBM/objective_function.h>" << std::endl;
str_buf << "#include <LightGBM/metric.h>" << std::endl;
str_buf << "#include <LightGBM/prediction_early_stop.h>" << std::endl;
str_buf << "#include <ctime>" << std::endl;
str_buf << "#include <sstream>" << std::endl;
str_buf << "#include <chrono>" << std::endl;
Expand Down Expand Up @@ -738,32 +738,32 @@ std::string GBDT::ModelToIfElse(int num_iteration) const {

std::stringstream pred_str_buf;

pred_str_buf << "\t" << "if (num_threads_ <= num_tree_per_iteration_) {" << std::endl;
pred_str_buf << "\t\t" << "#pragma omp parallel for schedule(static)" << std::endl;
pred_str_buf << "\t" << "const auto noEarlyStop = createPredictionEarlyStopInstance(\"none\", PredictionEarlyStopConfig());" << std::endl;
pred_str_buf << "\t" << "if (earlyStop == nullptr) {" << std::endl;
pred_str_buf << "\t\t" << "earlyStop = &noEarlyStop;" << std::endl;
pred_str_buf << "\t" << "}" << std::endl;

pred_str_buf << "\t" << "int earlyStopRoundCounter = 0;" << std::endl;
pred_str_buf << "\t" << "for (int i = 0; i < num_iteration_for_pred_; ++i) {" << std::endl;
pred_str_buf << "\t\t" << "for (int k = 0; k < num_tree_per_iteration_; ++k) {" << std::endl;
pred_str_buf << "\t\t\t" << "for (int i = 0; i < num_iteration_for_pred_; ++i) {" << std::endl;
pred_str_buf << "\t\t\t\t" << "output[k] += (*PredictTreePtr[i * num_tree_per_iteration_ + k])(features);" << std::endl;
pred_str_buf << "\t\t\t" << "}" << std::endl;
pred_str_buf << "\t\t\t" << "output[k] += (*PredictTreePtr[i * num_tree_per_iteration_ + k])(features);" << std::endl;
pred_str_buf << "\t\t" << "}" << std::endl;
pred_str_buf << "\t" << "} else {" << std::endl;
pred_str_buf << "\t\t" << "for (int k = 0; k < num_tree_per_iteration_; ++k) {" << std::endl;
pred_str_buf << "\t\t\t" << "double t = 0.0f;" << std::endl;
pred_str_buf << "\t\t\t" << "#pragma omp parallel for schedule(static) reduction(+:t)" << std::endl;
pred_str_buf << "\t\t\t" << "for (int i = 0; i < num_iteration_for_pred_; ++i) {" << std::endl;
pred_str_buf << "\t\t\t\t" << "t += (*PredictTreePtr[i * num_tree_per_iteration_ + k])(features);" << std::endl;
pred_str_buf << "\t\t\t" << "}" << std::endl;
pred_str_buf << "\t\t\t" << "output[k] = t;" << std::endl;
pred_str_buf << "\t\t" << "++earlyStopRoundCounter;" << std::endl;
pred_str_buf << "\t\t" << "if (earlyStop->roundPeriod == earlyStopRoundCounter) {" << std::endl;
pred_str_buf << "\t\t\t" << "if (earlyStop->callbackFunction(output, num_tree_per_iteration_))" << std::endl;
pred_str_buf << "\t\t\t\t" << "return;" << std::endl;
pred_str_buf << "\t\t\t" << "earlyStopRoundCounter = 0;" << std::endl;
pred_str_buf << "\t\t" << "}" << std::endl;
pred_str_buf << "\t" << "}" << std::endl;

str_buf << "void GBDT::PredictRaw(const double* features, double *output) const {" << std::endl;
str_buf << "void GBDT::PredictRaw(const double* features, double *output, const PredictionEarlyStopInstance* earlyStop) const {" << std::endl;
str_buf << pred_str_buf.str();
str_buf << "}" << std::endl;
str_buf << std::endl;

// Predict
str_buf << "void GBDT::Predict(const double* features, double *output) const {" << std::endl;
str_buf << pred_str_buf.str();
str_buf << "void GBDT::Predict(const double* features, double *output, const PredictionEarlyStopInstance* earlyStop) const {" << std::endl;
str_buf << "\t" << "PredictRaw(features, output, earlyStop);" << std::endl;
str_buf << "\t" << "if (objective_function_ != nullptr) {" << std::endl;
str_buf << "\t\t" << "objective_function_->ConvertOutput(output, output);" << std::endl;
str_buf << "\t" << "}" << std::endl;
Expand Down