Fix dmlc#3485, dmlc#3540: Don't use dropout for predicting test sets (dmlc#3556)

* Fix dmlc#3485, dmlc#3540: Don't use dropout for predicting test sets

Dropout (for DART) should only be used at training time.

* Add regression test
hcho3 authored Aug 5, 2018
1 parent 109473d commit 44811f2
Showing 5 changed files with 21 additions and 4 deletions.
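To make the behavioral change concrete: with booster='dart', prediction previously applied dropout just as training does, so repeated predict() calls on the same data could return different values. A minimal Python sketch of the symptom, using made-up data and parameters (nothing below is taken from the commit itself):

    import numpy as np
    import xgboost as xgb

    # Hypothetical data, purely for illustration.
    rng = np.random.RandomState(0)
    X_train, X_test = rng.rand(100, 10), rng.rand(50, 10)
    y_train = rng.randint(2, size=100)
    dtrain = xgb.DMatrix(X_train, label=y_train)
    dtest = xgb.DMatrix(X_test)

    num_round = 50
    param = {'booster': 'dart', 'objective': 'binary:logistic', 'rate_drop': 0.5}
    bst = xgb.train(param, dtrain, num_boost_round=num_round)

    # Before this commit, Dart::PredictBatch called DropTrees() on every
    # invocation, so these two results could differ; after it they must match.
    preds_a = bst.predict(dtest)
    preds_b = bst.predict(dtest)
    assert np.array_equal(preds_a, preds_b)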
3 changes: 3 additions & 0 deletions include/xgboost/gbm.h
@@ -76,11 +76,14 @@ class GradientBooster {
    * \brief generate predictions for given feature matrix
    * \param dmat feature matrix
    * \param out_preds output vector to hold the predictions
+   * \param dropout whether dropout should be applied to prediction
+   *    This option is only meaningful if booster='dart'; otherwise ignored.
    * \param ntree_limit limit the number of trees used in prediction, when it equals 0, this means
    *    we do not limit number of trees, this parameter is only valid for gbtree, but not for gblinear
    */
   virtual void PredictBatch(DMatrix* dmat,
                             HostDeviceVector<bst_float>* out_preds,
+                            bool dropout = true,
                             unsigned ntree_limit = 0) = 0;
   /*!
    * \brief online prediction function, predict score for one instance at a time
1 change: 1 addition & 0 deletions src/gbm/gblinear.cc
@@ -103,6 +103,7 @@ class GBLinear : public GradientBooster {

   void PredictBatch(DMatrix *p_fmat,
                     HostDeviceVector<bst_float> *out_preds,
+                    bool dropout,
                     unsigned ntree_limit) override {
     monitor_.Start("PredictBatch");
     CHECK_EQ(ntree_limit, 0U)
6 changes: 5 additions & 1 deletion src/gbm/gbtree.cc
@@ -217,6 +217,7 @@ class GBTree : public GradientBooster {

   void PredictBatch(DMatrix* p_fmat,
                     HostDeviceVector<bst_float>* out_preds,
+                    bool dropout,
                     unsigned ntree_limit) override {
     predictor_->PredictBatch(p_fmat, out_preds, model_, 0, ntree_limit);
   }
@@ -356,8 +357,11 @@ class Dart : public GBTree {
   // predict the leaf scores with dropout if ntree_limit = 0
   void PredictBatch(DMatrix* p_fmat,
                     HostDeviceVector<bst_float>* out_preds,
+                    bool dropout,
                     unsigned ntree_limit) override {
-    DropTrees(ntree_limit);
+    if (dropout) {
+      DropTrees(ntree_limit);
+    }
     PredLoopInternal<Dart>(p_fmat, &out_preds->HostVector(), 0, ntree_limit, true);
   }

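The effect of the guard: DropTrees(), which randomly selects a subset of trees to ignore, now runs only when the caller asks for dropout. One practical consequence, sketched under the same assumptions as the earlier snippet (bst, dtest, num_round are from that hypothetical example): the old workaround of passing a nonzero ntree_limit, which the existing test uses for preds2, is no longer needed to get deterministic DART predictions.

    # Old workaround: a nonzero ntree_limit bypassed dropout entirely.
    preds_limited = bst.predict(dtest, ntree_limit=num_round)
    # After this commit a plain predict() is deterministic and equivalent.
    preds_plain = bst.predict(dtest)
    assert np.array_equal(preds_limited, preds_plain)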
8 changes: 5 additions & 3 deletions src/learner.cc
@@ -469,7 +469,7 @@ class LearnerImpl : public Learner {
     } else if (pred_leaf) {
       gbm_->PredictLeaf(data, &out_preds->HostVector(), ntree_limit);
     } else {
-      this->PredictRaw(data, out_preds, ntree_limit);
+      this->PredictRaw(data, out_preds, false, ntree_limit);
       if (!output_margin) {
         obj_->PredTransform(out_preds);
       }
@@ -560,14 +560,16 @@ class LearnerImpl : public Learner {
    * \brief get un-transformed prediction
    * \param data training data matrix
    * \param out_preds output vector that stores the prediction
+   * \param dropout whether dropout should be applied to prediction.
+   *    This option is only meaningful if booster='dart'; otherwise ignored.
    * \param ntree_limit limit number of trees used for boosted tree
    *    predictor, when it equals 0, this means we are using all the trees
    */
   inline void PredictRaw(DMatrix* data, HostDeviceVector<bst_float>* out_preds,
-                         unsigned ntree_limit = 0) const {
+                         bool dropout = true, unsigned ntree_limit = 0) const {
     CHECK(gbm_ != nullptr)
         << "Predict must happen after Load or InitModel";
-    gbm_->PredictBatch(data, out_preds, ntree_limit);
+    gbm_->PredictBatch(data, out_preds, dropout, ntree_limit);
   }
 
   // model parameter
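Two details worth noting here. First, dropout defaults to true in PredictRaw, so any internal caller that omits the argument keeps its previous behavior; only the user-facing prediction path above opts out by passing false. Second, the surrounding context shows the raw-score/transform split: PredictRaw yields margins, and obj_->PredTransform is applied only when output_margin is false. A hedged Python illustration of that split, reusing bst and dtest from the earlier sketch (for binary:logistic the transform is the logistic sigmoid):

    margin = bst.predict(dtest, output_margin=True)  # raw scores from PredictRaw
    prob = bst.predict(dtest)                        # after PredTransform
    assert np.allclose(prob, 1.0 / (1.0 + np.exp(-margin)))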
7 changes: 7 additions & 0 deletions tests/python/test_basic_models.py
@@ -48,6 +48,13 @@ def test_dart(self):
         preds2 = bst2.predict(dtest2, ntree_limit=num_round)
         # assert they are the same
         assert np.sum(np.abs(preds2 - preds)) == 0
+        # regression test for issues #3485, #3540
+        for _ in range(10):
+            bst3 = xgb.Booster(params=param, model_file='xgb.model.dart')
+            dtest3 = xgb.DMatrix('dtest.buffer')
+            preds3 = bst3.predict(dtest3)
+            # assert they are the same
+            assert np.sum(np.abs(preds3 - preds)) == 0, 'preds3 = {}, preds = {}'.format(preds3, preds)
 
         # check whether sample_type and normalize_type work
         num_round = 50
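Two notes on the test. The loop runs ten times because tree dropout is random: a single extra predict() call could match the reference by coincidence, while ten consecutive matches make a silently re-enabled dropout path very unlikely to slip through. Also, the reload presupposes artifacts written earlier in test_dart; the round-trip presumably looks like this (a sketch reusing the test's file names, with the save calls assumed to appear before this hunk):

    # Assumed to run earlier in test_dart (not shown in this diff):
    bst.save_model('xgb.model.dart')
    dtest.save_binary('dtest.buffer')

    # The regression test then reloads both and compares predictions:
    bst3 = xgb.Booster(params=param, model_file='xgb.model.dart')
    dtest3 = xgb.DMatrix('dtest.buffer')
    assert np.sum(np.abs(bst3.predict(dtest3) - preds)) == 0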
