Skip to content

Commit

Permalink
Add early stopping for prediction
Browse files Browse the repository at this point in the history
  • Loading branch information
Carlos Becker committed May 26, 2017
1 parent 3abff37 commit 145bb44
Show file tree
Hide file tree
Showing 6 changed files with 165 additions and 1 deletion.
12 changes: 12 additions & 0 deletions include/LightGBM/boosting.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ namespace LightGBM {
class Dataset;
class ObjectiveFunction;
class Metric;
class PredictionEarlyStopInstance;

/*!
* \brief The interface for Boosting
Expand Down Expand Up @@ -134,6 +135,17 @@ class LIGHTGBM_EXPORT Boosting {
virtual void PredictLeafIndex(
const double* features, double* output) const = 0;

/*!
* \brief Prediction for one record, similar to PredictRaw but applies early stopping to speed up prediction.
* In contrast to PredictRaw(), this function runs on a single thread.
*
* \param features Feature values of this record
* \param output Prediction result for this record
* \param earlyStop Early stopping instance (the stop callback plus the period, in iterations, at which it is consulted)
*/
virtual void PredictRawEarlyStop(const double* features, double* output,
const PredictionEarlyStopInstance& earlyStop) const = 0;

/*!
* \brief Dump model to json format string
* \param num_iteration Number of iterations that want to dump, -1 means dump all
Expand Down
34 changes: 34 additions & 0 deletions include/LightGBM/prediction_early_stop.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
#ifndef LIGHTGBM_PREDICTION_EARLY_STOP_H_
#define LIGHTGBM_PREDICTION_EARLY_STOP_H_

#include <functional>
#include <string>

#include <LightGBM/export.h>

namespace LightGBM
{
/// A ready-to-use early stopping criterion: the callback that decides whether
/// to stop, together with how often (in iterations) it should be consulted.
struct PredictionEarlyStopInstance
{
/// Callback function type for early stopping.
/// Takes current prediction and number of elements in prediction
/// @returns true if prediction should stop according to criterion
using FunctionType = std::function<bool(const double*, int)>;

FunctionType callbackFunction; // callback function itself
int roundPeriod; // call callbackFunction every `roundPeriod` iterations
};

/// Parameters used by createPredictionEarlyStopInstance() to build an early stopping criterion.
struct PredictionEarlyStopConfig
{
int roundPeriod; // copied into PredictionEarlyStopInstance::roundPeriod by the "multiclass" and "binary" types
double marginThreshold; // the "multiclass" and "binary" callbacks request a stop once the margin exceeds this value
};

/// Create an early stopping algorithm of type `type`, with given roundPeriod and margin threshold
/// @param type one of "none", "multiclass" or "binary"
/// @param config round period and margin threshold; the "none" type does not read it
/// @returns the callback and round period implementing the requested criterion
/// @throws std::runtime_error if `type` is not one of the names above
LIGHTGBM_EXPORT PredictionEarlyStopInstance createPredictionEarlyStopInstance(const std::string& type,
const PredictionEarlyStopConfig& config);

} // namespace LightGBM

#endif // LIGHTGBM_PREDICTION_EARLY_STOP_H_
26 changes: 26 additions & 0 deletions src/boosting/gbdt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

#include <LightGBM/objective_function.h>
#include <LightGBM/metric.h>
#include <LightGBM/prediction_early_stop.h>

#include <ctime>

Expand Down Expand Up @@ -1003,4 +1004,29 @@ std::vector<std::pair<size_t, std::string>> GBDT::FeatureImportance() const {
return pairs;
}

/*!
* \brief Raw prediction for one record with periodic early-stopping checks.
*
* Tree outputs are added onto `output` (one slot per tree of an iteration), so
* whatever `output` holds on entry is part of the result. After every
* `earlyStop.roundPeriod` completed iterations the callback is handed the
* running totals; prediction ends as soon as it returns true.
*
* \param features Feature values of the record
* \param output Accumulator with num_tree_per_iteration_ entries
* \param earlyStop Callback and the period (in iterations) at which to call it
* \throws std::runtime_error if earlyStop.roundPeriod is less than one
*/
void GBDT::PredictRawEarlyStop(const double* features, double* output,
                               const PredictionEarlyStopInstance& earlyStop) const {
  const int period = earlyStop.roundPeriod;
  if (period < 1) {
    throw std::runtime_error("Tried to use early stopping with roundPeriod less than one");
  }

  for (int iter = 0; iter < num_iteration_for_pred_; ++iter) {
    // add the contribution of every tree belonging to this iteration
    const int first_tree = iter * num_tree_per_iteration_;
    for (int slot = 0; slot < num_tree_per_iteration_; ++slot) {
      output[slot] += models_[first_tree + slot]->Predict(features);
    }

    // the criterion is only evaluated once per `period` completed iterations
    const bool period_elapsed = ((iter + 1) % period) == 0;
    if (period_elapsed && earlyStop.callbackFunction(output, num_tree_per_iteration_)) {
      return;
    }
  }
}

} // namespace LightGBM
3 changes: 3 additions & 0 deletions src/boosting/gbdt.h
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,9 @@ class GBDT: public Boosting {
return num_preb_in_one_row;
}

/*!
* \brief Raw prediction for one record that can stop before all iterations are evaluated
* \param features Feature values of this record
* \param output Prediction result for this record
* \param earlyStop Early stopping instance (callback and the period at which it is called)
*/
void PredictRawEarlyStop(const double* features, double* output,
const PredictionEarlyStopInstance& earlyStop) const override;

void PredictRaw(const double* features, double* output) const override;

void Predict(const double* features, double* output) const override;
Expand Down
3 changes: 2 additions & 1 deletion src/boosting/gbdt_prediction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

#include <LightGBM/objective_function.h>
#include <LightGBM/metric.h>
#include <LightGBM/prediction_early_stop.h>

#include <ctime>

Expand Down Expand Up @@ -68,4 +69,4 @@ void GBDT::PredictLeafIndex(const double* features, double* output) const {
}
}

} // namespace LightGBM
} // namespace LightGBM
88 changes: 88 additions & 0 deletions src/boosting/prediction_early_stop.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
#include <LightGBM/prediction_early_stop.h>

#include <algorithm>
#include <cmath>
#include <stdexcept>
#include <vector>

using namespace LightGBM;

namespace
{
/// Builds the "none" criterion: a callback that never asks prediction to stop.
/// The configuration argument is ignored.
PredictionEarlyStopInstance createNone(const PredictionEarlyStopConfig&)
{
  // arbitrary number high enough so that the function will rarely be called
  const int rarelyCalledPeriod = 100000000;

  const PredictionEarlyStopInstance::FunctionType neverStop =
    [](const double*, int) { return false; };

  return PredictionEarlyStopInstance{neverStop, rarelyCalledPeriod};
}

/// Builds the "multiclass" criterion.
///
/// The callback computes the margin between the largest and the second largest
/// entry of the current prediction and requests a stop once that margin is
/// strictly greater than `config.marginThreshold`. It is meant to be consulted
/// every `config.roundPeriod` iterations.
///
/// The callback throws std::runtime_error when handed fewer than two scores,
/// because no margin exists in that case (the previous implementation read
/// past the end of its buffer).
PredictionEarlyStopInstance createMulticlass(const PredictionEarlyStopConfig& config)
{
  // marginThreshold will be captured by value
  const double marginThreshold = config.marginThreshold;

  return PredictionEarlyStopInstance{
    [marginThreshold](const double* pred, int sz)
    {
      if (sz < 2)
      {
        throw std::runtime_error("Multiclass early stopping needs predictions to be of length two or larger");
      }

      // single pass for the two largest scores; avoids copying and sorting
      // the prediction on every invocation of the callback
      double best = std::max(pred[0], pred[1]);
      double runnerUp = std::min(pred[0], pred[1]);
      for (int i = 2; i < sz; ++i)
      {
        const double score = pred[i];
        if (score > best)
        {
          runnerUp = best;
          best = score;
        }
        else if (score > runnerUp)
        {
          runnerUp = score;
        }
      }

      return (best - runnerUp) > marginThreshold;
    },
    config.roundPeriod
  };
}

/// Builds the "binary" criterion.
///
/// The callback looks only at the first prediction entry, takes twice its
/// absolute value as the margin, and requests a stop once that margin is
/// strictly greater than `config.marginThreshold`. The resulting instance
/// carries `config.roundPeriod` as its round period.
PredictionEarlyStopInstance createBinary(const PredictionEarlyStopConfig& config)
{
  // copied into locals so the closure owns its threshold by value
  const double threshold = config.marginThreshold;
  const int period = config.roundPeriod;

  const PredictionEarlyStopInstance::FunctionType shouldStop =
    [threshold](const double* pred, int)
    {
      return 2.0 * std::fabs(pred[0]) > threshold;
    };

  return PredictionEarlyStopInstance{shouldStop, period};
}
}

namespace LightGBM
{
/// Selects the early stopping criterion named by `type`.
///
/// Recognised names are "none", "multiclass" and "binary"; each is forwarded
/// to the matching file-local factory together with `config`.
///
/// @throws std::runtime_error if `type` is none of the recognised names
PredictionEarlyStopInstance createPredictionEarlyStopInstance(const std::string& type,
                                                              const PredictionEarlyStopConfig& config)
{
  if (type == "none")
    return createNone(config);

  if (type == "multiclass")
    return createMulticlass(config);

  if (type == "binary")
    return createBinary(config);

  // no factory matched
  throw std::runtime_error("Unknown early stopping type: " + type);
}
}

0 comments on commit 145bb44

Please sign in to comment.