Skip to content

Commit

Permalink
fix tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
guolinke committed Mar 22, 2017
1 parent e179c7c commit 2e962c7
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 5 deletions.
8 changes: 5 additions & 3 deletions src/boosting/gbdt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -630,6 +630,7 @@ std::string GBDT::DumpModel(int num_iteration) const {
str_buf << "\"tree_info\":[";
int num_used_model = static_cast<int>(models_.size());
if (num_iteration > 0) {
num_iteration += boost_from_average_ ? 1 : 0;
num_used_model = std::min(num_iteration * num_class_, num_used_model);
}
for (int i = 0; i < num_used_model; ++i) {
Expand All @@ -648,7 +649,7 @@ std::string GBDT::DumpModel(int num_iteration) const {
return str_buf.str();
}

std::string GBDT::SaveModelToString(int num_iterations) const {
std::string GBDT::SaveModelToString(int num_iteration) const {
std::stringstream ss;

// output model type
Expand Down Expand Up @@ -676,8 +677,9 @@ std::string GBDT::SaveModelToString(int num_iterations) const {

ss << std::endl;
int num_used_model = static_cast<int>(models_.size());
if (num_iterations > 0) {
num_used_model = std::min(num_iterations * num_class_, num_used_model);
if (num_iteration > 0) {
num_iteration += boost_from_average_ ? 1 : 0;
num_used_model = std::min(num_iteration * num_class_, num_used_model);
}
// output tree models
for (int i = 0; i < num_used_model; ++i) {
Expand Down
2 changes: 1 addition & 1 deletion src/boosting/gbdt.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ class GBDT: public Boosting {
*/
void RollbackOneIter() override;

int GetCurrentIteration() const override { return iter_ + num_init_iteration_; }
int GetCurrentIteration() const override { return static_cast<int>(models_.size()) / num_class_; }

bool EvalAndCheckEarlyStopping() override;

Expand Down
2 changes: 1 addition & 1 deletion tests/python_package_test/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class template(object):
@staticmethod
def test_template(params={'objective': 'regression', 'metric': 'l2'},
X_y=load_boston(True), feval=mean_squared_error,
num_round=150, init_model=None, custom_eval=None,
num_round=200, init_model=None, custom_eval=None,
early_stopping_rounds=10,
return_data=False, return_model=False):
params['verbose'], params['seed'] = -1, 42
Expand Down

0 comments on commit 2e962c7

Please sign in to comment.