
Add prediction early stopping #550

Merged 8 commits on May 29, 2017


@cbecker cbecker commented May 26, 2017

Adds early stopping for classification prediction, along with the necessary API. The stopping decision is based on comparing the classification margin against a threshold while prediction is performed.
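
As a rough illustration of the idea (not the code in this PR), margin-based early stopping can be sketched in Python. The parameter names `round_period` and `margin_threshold` follow the terms used later in this thread; the function names and the tree representation (a list of callables returning one score increment per class) are assumptions made only for this sketch.

```python
from typing import Callable, List, Sequence, Tuple

# Hypothetical stand-in for a tree: maps a feature vector to one score
# increment per class. This is an assumption, not LightGBM's actual type.
Tree = Callable[[Sequence[float]], List[float]]


def top_two_margin(scores: Sequence[float]) -> float:
    """Gap between the largest and second-largest raw score (needs >= 2 scores)."""
    best, second = sorted(scores, reverse=True)[:2]
    return best - second


def predict_raw_early_stop(
    trees: Sequence[Tree],
    features: Sequence[float],
    num_class: int,
    round_period: int,
    margin_threshold: float,
) -> Tuple[List[float], int]:
    """Accumulate tree outputs, checking the margin every `round_period` trees.

    Returns the raw scores and the number of trees actually evaluated.
    """
    scores = [0.0] * num_class
    used = 0
    for i, tree in enumerate(trees, start=1):
        for k, delta in enumerate(tree(features)):
            scores[k] += delta
        used = i
        # Only test the margin periodically to keep the check cheap.
        if i % round_period == 0 and top_two_margin(scores) > margin_threshold:
            break
    return scores, used
```

Samples whose margin grows quickly stop after a few trees, while ambiguous samples still use the full model.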

* \param earlyStop Early stopping instance
*/
virtual void PredictRawEarlyStop(const double* features, double* output,
const PredictionEarlyStopInstance& earlyStop) const = 0;
Contributor Author

An early stopping struct instance is needed to predict a sample. This separates prediction early stopping from the GBDT configuration, which I think helps keep the code and functionality separate.


/// Create an early stopping algorithm of type `type`, with given roundPeriod and margin threshold
LIGHTGBM_EXPORT PredictionEarlyStopInstance createPredictionEarlyStopInstance(const std::string& type,
const PredictionEarlyStopConfig& config);
Contributor Author

Possible types are none, multiclass and binary.

std::vector<double> votes(static_cast<size_t>(sz));
for (int i = 0; i < sz; ++i)
  votes[i] = pred[i];
std::partial_sort(votes.begin(), votes.begin() + 2, votes.end(), std::greater<double>());
Contributor Author

Right now we're not verifying that the prediction vector has at least 2 elements. @guolinke, would it be OK to check for it here and throw an exception if this is not met?

This means that PredictRawEarlyStop would throw.

Collaborator

You can just use the binary-class margin when votes.size() == 1.
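
A sketch of this suggestion, written in Python for brevity: with a single raw score (the binary case) the margin comes from that score alone, otherwise it is the gap between the two largest scores. Using the absolute value of the single score as the binary margin is an assumption of this sketch, and the PR may scale it differently. The function name `prediction_margin` is hypothetical.

```python
import heapq
from typing import Sequence


def prediction_margin(pred: Sequence[float]) -> float:
    """Margin of a raw prediction vector, with a fallback for a single score."""
    if len(pred) == 0:
        raise ValueError("prediction vector must not be empty")
    if len(pred) == 1:
        # Binary case: one raw score. Its distance from 0 is used as the
        # margin here, which is an assumption made for this sketch.
        return abs(pred[0])
    # Multiclass case: difference between the two largest scores.
    best, second = heapq.nlargest(2, pred)
    return best - second
```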

@guolinke
Collaborator

@cbecker
This is not an accurate prediction, right? I mean, it may reduce the prediction accuracy.

@cbecker
Contributor Author

cbecker commented May 26, 2017

> This is not an accurate prediction, right? I mean, it may reduce the prediction accuracy.

Yes, indeed. It's something that is useful when speed is a concern. Compared to limiting the number of trees used for prediction, this type of early stopping adapts automatically to each instance.

@guolinke
Collaborator

guolinke commented May 26, 2017

OK. BTW, do we really need to expose the callback? Or is letting the user set roundPeriod and marginThreshold enough?

@cbecker
Contributor Author

cbecker commented May 26, 2017

> OK. BTW, do we really need to expose the callback? Or is letting the user set roundPeriod and marginThreshold enough?

I thought that exposing the callback would be useful if the user wants a custom asymmetric early stopping method. For example, if class 0 is background and class 1 is foreground, many people place a threshold on the negative class score and leave the positive one without early stopping, e.g. stop if the binary score < -1.0.
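
The asymmetric rule described above could look roughly like this as a user-supplied callback. This is a Python sketch only: the callback signature (raw scores in, stop decision out) and the name `make_negative_only_stop` are assumptions, not the interface added by this PR.

```python
from typing import Callable, Sequence

# Assumed callback shape: receives the current raw scores, returns True to stop.
EarlyStopCallback = Callable[[Sequence[float]], bool]


def make_negative_only_stop(threshold: float = -1.0) -> EarlyStopCallback:
    """Build a callback that stops only on confident background predictions."""

    def should_stop(pred: Sequence[float]) -> bool:
        # pred[0] is the binary raw score; positive scores never trigger a stop.
        return pred[0] < threshold

    return should_stop
```

With `stop = make_negative_only_stop(-1.0)`, a score of -1.5 stops prediction early, while a score of 4.0 keeps evaluating trees no matter how confident it is.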

@cbecker
Contributor Author

cbecker commented May 26, 2017

Also, do you know why the tests may be failing? There is an undefined reference to PredictRawEarlyStop, which looks strange to me, since it compiles successfully on my desktop (Linux, GCC).

[ 51%] Building CXX object CMakeFiles/lightgbm.dir/src/boosting/gbdt_prediction.cpp.o
Linking CXX executable ../lightgbm
CMakeFiles/lightgbm.dir/src/boosting/gbdt.cpp.o:(.data.rel.ro._ZTVN8LightGBM4GBDTE[_ZTVN8LightGBM4GBDTE]+0xa0): undefined reference to `LightGBM::GBDT::PredictRawEarlyStop(double const*, double*, LightGBM::PredictionEarlyStopInstance const&) const'
CMakeFiles/lightgbm.dir/src/boosting/boosting.cpp.o:(.data.rel.ro._ZTVN8LightGBM4DARTE[_ZTVN8LightGBM4DARTE]+0xa0): undefined reference to `LightGBM::GBDT::PredictRawEarlyStop(double const*, double*, LightGBM::PredictionEarlyStopInstance const&) const'
CMakeFiles/lightgbm.dir/src/boosting/boosting.cpp.o:(.data.rel.ro._ZTVN8LightGBM4GOSSE[_ZTVN8LightGBM4GOSSE]+0xa0): undefined reference to `LightGBM::GBDT::PredictRawEarlyStop(double const*, double*, LightGBM::PredictionEarlyStopInstance const&) const'
collect2: error: ld returned 1 exit status
make[2]: *** [../lightgbm] Error 1

@guolinke
Collaborator

@cbecker Please refer to PR #482. It will replace the gbdt_prediction.cpp file, so I think you may need to write your prediction code in a new cpp file.

@cbecker
Contributor Author

cbecker commented May 26, 2017

> @cbecker Please refer to PR #482. It will replace the gbdt_prediction.cpp file, so I think you may need to write your prediction code in a new cpp file.

Thanks. The fact that it overwrites its own code looks a bit tricky to me. I'll put my code in a separate file. That will do for now, but it won't be able to deal with pre-compiled models. By the way, what speedup do you achieve by having the model in if/else C++ statements?

@wxchan
Contributor

wxchan commented May 26, 2017

I think you can put your code in gbdt.cpp.

@cbecker
Contributor Author

cbecker commented May 26, 2017

> I think you can put your code in gbdt.cpp.

Either there or the current file works for me. Let me know what you prefer and I'll make the changes and rename the commit so we can merge 👍

@cbecker
Contributor Author

cbecker commented May 26, 2017

Done now :)

@guolinke
Collaborator

@cbecker It seems this feature is not easy to use. Must the user write their own C++ code to call this function?

@cbecker
Contributor Author

cbecker commented May 28, 2017

> @cbecker It seems this feature is not easy to use. Must the user write their own C++ code to call this function?

True, I will add the corresponding C API tomorrow and get back to you.

@cbecker
Contributor Author

cbecker commented May 29, 2017

@guolinke would you agree with removing the parallel loops in PredictRaw() and Predict() now? The code could then be modularized and kept much simpler, and we would have most of the prediction code in a single function instead of 3 of them (including early stopping).

I can do this in a commit in this PR

@guolinke
Collaborator

@cbecker
OK, you can make them single-threaded.

@guolinke
Collaborator

guolinke commented May 29, 2017

@cbecker
BTW, can you also support this for the CLI version?

@guolinke guolinke closed this May 29, 2017
@guolinke guolinke reopened this May 29, 2017
@cbecker cbecker force-pushed the earlyStopping branch 2 times, most recently from bd4ff80 to 1a8e807 Compare May 29, 2017 09:37
@cbecker
Contributor Author

cbecker commented May 29, 2017

Tests won't pass yet, I have to modify ModelToIfElse() first, but the overall functionality is there now.

@cbecker
Contributor Author

cbecker commented May 29, 2017

@guolinke I am wondering why the two new C APIs are not being exported to the .so file. Have you seen this before? Are you stripping symbols at some stage, and if so, how do you control which ones to keep?

EDIT: I think it may be that my definition is not inside extern "C". I will check.

src/c_api.cpp Outdated
@@ -162,6 +163,7 @@ class Booster {

void Predict(int num_iteration, int predict_type, int nrow,
std::function<std::vector<std::pair<int, double>>(int row_idx)> get_row_fun,
PredictionEarlyStoppingHandle early_stop_handle,
Collaborator

I think we cannot use a function pointer in the C API. It is hard to use from Python/R.

Collaborator

Oh, I was wrong. It seems you use a handle, so it is a class.

@guolinke
Collaborator

guolinke commented May 29, 2017

Parameters
----------
early_stop_type: string
"none", "binary" or "multiclass". Regression is not supported.
Contributor

I think you can just use None instead of string 'none' here

Contributor

Never mind, it seems "none" is better.

Contributor Author

Thanks, I allow for both options now.
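
A minimal sketch of how accepting both spellings might look on the Python side. The helper name `normalize_early_stop_type` is hypothetical and not taken from the PR; only the three type names come from the docstring under review.

```python
from typing import Optional

# Type names listed in the docstring under review.
_VALID_TYPES = ("none", "binary", "multiclass")


def normalize_early_stop_type(early_stop_type: Optional[str]) -> str:
    """Map None to "none" and reject anything outside the known types."""
    if early_stop_type is None:
        return "none"
    if early_stop_type not in _VALID_TYPES:
        raise ValueError(
            "early_stop_type must be None or one of %s, got %r"
            % (", ".join(_VALID_TYPES), early_stop_type)
        )
    return early_stop_type
```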

@cbecker
Contributor Author

cbecker commented May 29, 2017

@guolinke Can we also support this in the CLI version, by passing the parameters when running LightGBM? We could also create a singleton handle in the CLI version, so that we don't need to create a handle instance every time prediction is called (https://github.com/cbecker/LightGBM/blob/5f50904dceec524451e17bae1ac7308ab180fc50/src/boosting/gbdt_prediction.cpp#L22-L26).

I am not sure where to modify the code for that. I think it'd be safer if you do that, as I am not familiar with the CLI code. Let me know if you agree with the state of the current PR.

On my side I am getting a 2x speedup with very little loss in accuracy in one of my tests. I am classifying millions of samples, and the run time went down from 1 hour to half an hour :)

@wxchan
Contributor

wxchan commented May 29, 2017

I think it's ready to merge for python and if-else part. You can add an example in examples/python-guide/advanced_example.py if you want.

@wxchan
Contributor

wxchan commented May 29, 2017

By the way, you need to update the windows folder for the new files. It's easy to forget.

@guolinke
Collaborator

@cbecker OK, I can do it.
BTW, I updated your code directly to parallelize over samples.

@cbecker
Contributor Author

cbecker commented May 29, 2017

> By the way, you need to update the windows folder for the new files. It's easy to forget.

I'm not sure what this means, is it something I have to do?

@guolinke
Collaborator

@cbecker
I think wxchan's point is to update the lightgbm.vcproj in the windows folder.

@wxchan
Contributor

wxchan commented May 29, 2017

yes, LightGBM.vcxproj and LightGBM.vcxproj.filters

@cbecker
Contributor Author

cbecker commented May 29, 2017

> yes, LightGBM.vcxproj and LightGBM.vcxproj.filters

Thanks, I edited the file and made the changes manually as I don't have MSVC here, let me know if there are any issues.

@guolinke
Collaborator

@wxchan Do you have any other comments ?

@wxchan
Contributor

wxchan commented May 29, 2017

@guolinke no.

@guolinke guolinke merged commit 993bbd5 into microsoft:master May 29, 2017
@ddDragon

ddDragon commented Jun 5, 2017

I used the early stopping parameter about a month ago and found that it didn't work. Does this version fix the issue? Or maybe the way I used it was wrong?

@guolinke
Collaborator

guolinke commented Jun 5, 2017

@ddDragon This is the early stopping for the prediction, not the training.

And I think early stopping for training has always worked. Maybe you are using it incorrectly.

@ddDragon

ddDragon commented Jun 6, 2017

I used the early_stopping_rounds parameter of the lgb.train API. I set early stopping to 30 rounds, with 5000 training iterations in total. The best validation score was reached at about round 2000, but the model always trained for all 5000 rounds. This confused me a lot.
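
For reference, the behaviour one would expect from a patience-style rule such as early_stopping_rounds can be sketched in plain Python. This is not LightGBM's implementation. The function name `run_with_early_stopping` is hypothetical, and the sketch assumes that a lower validation score is better.

```python
from typing import Iterable, Tuple


def run_with_early_stopping(
    valid_scores: Iterable[float], early_stopping_rounds: int
) -> Tuple[int, int]:
    """Return (best_iteration, last_iteration_run), both 1-based.

    Training stops once the validation score has not improved for
    `early_stopping_rounds` consecutive iterations.
    """
    best_score = float("inf")
    best_iter = 0
    last_iter = 0
    for i, score in enumerate(valid_scores, start=1):
        last_iter = i
        if score < best_score:
            # New best validation score: remember where it happened.
            best_score, best_iter = score, i
        elif i - best_iter >= early_stopping_rounds:
            # No improvement for `early_stopping_rounds` iterations: stop.
            break
    return best_iter, last_iter
```

Under this rule, a validation curve whose best score is at round 2000 and which only gets worse afterwards would stop at round 2030 with a patience of 30, rather than running all 5000 rounds.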

@guolinke
Collaborator

guolinke commented Jun 6, 2017

@wxchan any idea about that?
@ddDragon can you provide a script to reproduce it?

@ddDragon

ddDragon commented Jun 6, 2017

Sorry, I didn't back up my code. Maybe something was wrong in it. If I run into this problem again, I will save the script. Lgb is much faster than xgb without losing accuracy. Very nice tool!

guolinke pushed a commit that referenced this pull request Oct 9, 2017
* Add early stopping for prediction

* Fix GBDT if-else prediction with early stopping

* Small C++ embelishments to early stopping API and functions

* Fix early stopping efficiency issue by creating a singleton for no early stopping

* Python improvements to early stopping API

* Add assertion check for binary and multiclass prediction score length

* Update vcxproj and vcxproj.filters with new early stopping files

* Remove inline from PredictRaw(), the linker was not able to find it otherwise
@guolinke guolinke mentioned this pull request Apr 1, 2018
@lock lock bot locked as resolved and limited conversation to collaborators Mar 12, 2020