Add prediction early stopping #550
Conversation
include/LightGBM/boosting.h
 * \param earlyStop Early stopping instance
 */
virtual void PredictRawEarlyStop(const double* features, double* output,
                                 const PredictionEarlyStopInstance& earlyStop) const = 0;
An early stopping struct instance is needed to predict a sample; this separates prediction early stopping from the configuration of GBDT, which I think helps keep the code and functionality separate.
/// Create an early stopping algorithm of type `type`, with given roundPeriod and margin threshold
LIGHTGBM_EXPORT PredictionEarlyStopInstance createPredictionEarlyStopInstance(const std::string& type,
                                                                              const PredictionEarlyStopConfig& config);
Possible types are `none`, `multiclass` and `binary`.
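Since the actual config and instance definitions aren't shown in this excerpt, here is a minimal self-contained sketch of what the factory might look like. The field names `round_period` and `margin_threshold`, and the callback shape, are assumptions based on the "roundPeriod and margin threshold" mentioned in the doc comment above, not LightGBM's actual definitions.

```cpp
#include <cmath>
#include <functional>
#include <stdexcept>
#include <string>

// Hypothetical config: field names are assumptions mirroring the
// "roundPeriod and margin threshold" mentioned in the doc comment.
struct PredictionEarlyStopConfig {
  int round_period;
  double margin_threshold;
};

// Hypothetical instance: a stop-test callback plus how often to run it.
struct PredictionEarlyStopInstance {
  std::function<bool(const double* pred, int sz)> callback;  // true => stop
  int round_period;
};

// Sketch of the factory for the "none" and "binary" types; "multiclass"
// would compare the gap between the two largest scores instead.
PredictionEarlyStopInstance CreatePredictionEarlyStopInstance(
    const std::string& type, const PredictionEarlyStopConfig& config) {
  if (type == "none") {
    // Never stops; round period 1 keeps the check loop trivial.
    return {[](const double*, int) { return false; }, 1};
  }
  if (type == "binary") {
    const double threshold = config.margin_threshold;
    return {[threshold](const double* pred, int) {
              return std::fabs(pred[0]) > threshold;  // |margin| is decisive
            },
            config.round_period};
  }
  throw std::runtime_error("Unknown early stopping type: " + type);
}
```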
std::vector<double> votes(static_cast<size_t>(sz));
for (int i = 0; i < sz; ++i)
  votes[i] = pred[i];
std::partial_sort(votes.begin(), votes.begin() + 2, votes.end(), std::greater<double>());
Right now we're not verifying that the prediction vector is at least of size 2, @guolinke would it be ok to check for it here and throw an exception if this is not met?
This means that `PredictRawEarlyStop` would throw.
You can just use the margin of the binary class when `votes.size() == 1`.
@cbecker
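A self-contained sketch of the check being discussed, including the suggested binary fallback for a single score. The function name and threshold parameter are illustrative, not LightGBM's actual API:

```cpp
#include <algorithm>
#include <cmath>
#include <functional>
#include <vector>

// Multiclass stop test: the margin is the gap between the two largest class
// scores; with only one score, fall back to the binary |margin| as suggested.
bool MulticlassShouldStop(const double* pred, int sz, double threshold) {
  if (sz == 1) {
    return std::fabs(pred[0]) > threshold;  // binary fallback
  }
  std::vector<double> votes(pred, pred + sz);
  // Only the top two entries need to be ordered, hence partial_sort.
  std::partial_sort(votes.begin(), votes.begin() + 2, votes.end(),
                    std::greater<double>());
  return (votes[0] - votes[1]) > threshold;
}
```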
Yes, indeed. It's something that is useful when speed is a concern. Compared to limiting the number of trees used for prediction, this type of early stopping 'adapts' automatically to the training instance.
OK. BTW, do we really need to expose the callback? Or just let the user set the
I thought that exposing the callback is useful if the user wants a custom asymmetric early stopping method. E.g. if class 0 is background and 1 is foreground, one may place a threshold on the negative class score and leave the positive one without early stopping (e.g. stop if the binary score < -1.0).
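The asymmetric rule described above can be sketched as a custom callback. The factory name and the callback signature (raw scores in, `true` means stop) are illustrative assumptions:

```cpp
#include <functional>

// Illustrative factory for the asymmetric rule: stop as soon as the binary
// score is confidently negative, never stop on the positive side.
std::function<bool(const double* pred, int sz)> MakeNegativeOnlyStop(
    double neg_threshold) {
  return [neg_threshold](const double* pred, int /*sz*/) {
    return pred[0] < neg_threshold;  // e.g. stop if binary score < -1.0
  };
}
```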
Also, do you know why the tests may be failing? There is an undefined reference to
Thanks, that looks a bit tricky to me, given that it's overwriting its own code. I'll put it in a separate file; at least for now it will do, but it won't be able to deal with pre-compiled models. Btw, what speed-up do you achieve by having the model in C++ if/else statements?
I think you can put your code in gbdt.cpp.
For me, there or in the current file works too. Let me know what you prefer and I'll make the changes and rename the commit, if we can merge 👍
Done now :)
@cbecker It seems this feature is not easy to use. Must users write their own C++ code to call this function?
True, I will add the respective C API tomorrow and get back to you.
@guolinke would you agree with removing the parallel loops in PredictRaw() and Predict() now? Then the code can be modularized and kept much simpler, and we'll have most of the prediction code in a single function instead of three (including early stopping). I can do this in a commit in this PR.
(force-pushed from bd4ff80 to 1a8e807)
Tests won't pass yet, I have to modify
@guolinke I am wondering why the two new C APIs are not being exported to the .so file, have you seen this before? Are you stripping symbols at some stage, and if so, how do you control which ones to keep? EDIT: I think it may be that my definition is not in extern "C". I will check.
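For reference, the symptom described (symbols missing from the .so) is what C++ name mangling produces when a C API function is not wrapped in `extern "C"`. A minimal illustration with a hypothetical function name:

```cpp
// Without extern "C", the C++ compiler mangles the symbol name, so code
// looking for the plain C name in the shared library won't find it.
// LGBM_ExampleAdd is a made-up function, not part of LightGBM's C API.
extern "C" {
int LGBM_ExampleAdd(int a, int b);  // exported under its unmangled name
}

int LGBM_ExampleAdd(int a, int b) { return a + b; }
```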
src/c_api.cpp
@@ -162,6 +163,7 @@ class Booster {

void Predict(int num_iteration, int predict_type, int nrow,
             std::function<std::vector<std::pair<int, double>>(int row_idx)> get_row_fun,
             PredictionEarlyStoppingHandle early_stop_handle,
I think we cannot use a function pointer in the C API. It is hard to use from Python/R.
Oh, I was wrong; it seems you use a handle, so it is a class.
@cbecker I think you should move these two APIs outside of this class.
python-package/lightgbm/basic.py
Parameters
----------
early_stop_type: string
    "none", "binary" or "multiclass". Regression is not supported.
I think you can just use `None` instead of the string 'none' here.
Never mind, it seems "none" is better.
Thanks, I allow for both options now.
I am not sure where to modify the code for that. I think it'd be safer if you do that, as I am not familiar with the CLI code. Let me know if you agree with the state of the current PR. On my side I am getting a 2x speed-up with very little loss in accuracy in one of my tests: I am classifying millions of samples, and it went down from 1 hour to half an hour :)
I think it's ready to merge for the Python and if-else parts. You can add an example in examples/python-guide/advanced_example.py if you want.
By the way, you need to update
@cbecker OK, I can do it.
I'm not sure what this means; is it something I have to do?
@cbecker yes, LightGBM.vcxproj and LightGBM.vcxproj.filters
Thanks, I edited the file and made the changes manually as I don't have MSVC here; let me know if there are any issues.
@wxchan Do you have any other comments?
@guolinke no.
I used the early stopping parameter about a month ago and found that it didn't work. Does this version fix the issue? Or maybe the method I used was wrong?
@ddDragon This is the early stopping for prediction, not for training. And I think early stopping for training has always worked; maybe you used it the wrong way.
I used the parameter named early_stopping_rounds in the lgb.train API. I set 30 rounds of early stopping with 5000 total training iterations. I found the best validation set score was at about 2000 rounds, but every time the model trained for all 5000 rounds. This confused me a lot.
So sorry that I didn't back up my code. Maybe something was wrong in my code; if I meet this problem next time I will save it. LightGBM is much faster than XGBoost without loss of accuracy. Very nice tool!
* Add early stopping for prediction
* Fix GBDT if-else prediction with early stopping
* Small C++ embellishments to early stopping API and functions
* Fix early stopping efficiency issue by creating a singleton for no early stopping
* Python improvements to early stopping API
* Add assertion check for binary and multiclass prediction score length
* Update vcxproj and vcxproj.filters with new early stopping files
* Remove inline from PredictRaw(), the linker was not able to find it otherwise
Adds classification prediction early stopping and the necessary API. It is based on comparisons of the classification margin as prediction is performed.
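The mechanism described can be sketched as follows: accumulate tree outputs and, every few trees, ask a stop test whether the margin is already decisive. All names here are illustrative (not LightGBM's internals), and tree outputs are stubbed as precomputed doubles:

```cpp
#include <functional>
#include <vector>

// Sketch of prediction with early stopping: sum tree outputs and, every
// `round_period` trees, let the callback decide whether the running margin
// is already decisive enough to skip the remaining trees.
double PredictWithEarlyStop(
    const std::vector<double>& tree_outputs,
    const std::function<bool(const double* pred, int sz)>& should_stop,
    int round_period, int* trees_used) {
  double score = 0.0;
  int used = 0;
  for (double out : tree_outputs) {
    score += out;
    ++used;
    if (used % round_period == 0 && should_stop(&score, 1)) {
      break;  // margin already decisive: skip the remaining trees
    }
  }
  if (trees_used != nullptr) *trees_used = used;
  return score;
}
```

This is where the reported 2x speed-up comes from: confidently classified samples exit the tree loop long before all iterations have been evaluated.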