Skip to content

Commit

Permalink
[python] Allow python sklearn interface's fit() to pass init_model to train() (#2447)
Browse files Browse the repository at this point in the history

* allow python sklearn interface's fit() to pass init_model to train()

* Fix whitespace issues, and change ordering of parameters to be backward
compatible

* Formatting fixes

* allow python sklearn interface's fit() to pass init_model to train()

* Fix whitespace issues, and change ordering of parameters to be backward
compatible

* Formatting fixes

* Recognize LGBModel objects for init_model

* simplified condition

* updated docstring

* added test
  • Loading branch information
aaiyer authored and StrikerRUS committed Dec 5, 2019
1 parent 69c1c33 commit f3afe98
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 8 deletions.
25 changes: 17 additions & 8 deletions python-package/lightgbm/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,8 @@ def fit(self, X, y,
eval_set=None, eval_names=None, eval_sample_weight=None,
eval_class_weight=None, eval_init_score=None, eval_group=None,
eval_metric=None, early_stopping_rounds=None, verbose=True,
feature_name='auto', categorical_feature='auto', callbacks=None):
feature_name='auto', categorical_feature='auto',
callbacks=None, init_model=None):
"""Build a gradient boosting model from the training set (X, y).
Parameters
Expand Down Expand Up @@ -442,6 +443,8 @@ def fit(self, X, y,
callbacks : list of callback functions or None, optional (default=None)
List of callback functions that are applied at each iteration.
See Callbacks in Python API for more information.
init_model : string, Booster, LGBMModel or None, optional (default=None)
Filename of LightGBM model, Booster instance or LGBMModel instance used for continue training.
Returns
-------
Expand Down Expand Up @@ -593,13 +596,16 @@ def _get_meta_data(collection, name, i):
valid_weight, valid_init_score, valid_group, params)
valid_sets.append(valid_set)

if isinstance(init_model, LGBMModel):
init_model = init_model.booster_

self._Booster = train(params, train_set,
self.n_estimators, valid_sets=valid_sets, valid_names=eval_names,
early_stopping_rounds=early_stopping_rounds,
evals_result=evals_result, fobj=self._fobj, feval=feval,
verbose_eval=verbose, feature_name=feature_name,
categorical_feature=categorical_feature,
callbacks=callbacks)
callbacks=callbacks, init_model=init_model)

if evals_result:
self._evals_result = evals_result
Expand Down Expand Up @@ -731,7 +737,8 @@ def fit(self, X, y,
sample_weight=None, init_score=None,
eval_set=None, eval_names=None, eval_sample_weight=None,
eval_init_score=None, eval_metric=None, early_stopping_rounds=None,
verbose=True, feature_name='auto', categorical_feature='auto', callbacks=None):
verbose=True, feature_name='auto', categorical_feature='auto',
callbacks=None, init_model=None):
"""Docstring is inherited from the LGBMModel."""
super(LGBMRegressor, self).fit(X, y, sample_weight=sample_weight,
init_score=init_score, eval_set=eval_set,
Expand All @@ -742,7 +749,7 @@ def fit(self, X, y,
early_stopping_rounds=early_stopping_rounds,
verbose=verbose, feature_name=feature_name,
categorical_feature=categorical_feature,
callbacks=callbacks)
callbacks=callbacks, init_model=init_model)
return self

_base_doc = LGBMModel.fit.__doc__
Expand All @@ -758,7 +765,8 @@ def fit(self, X, y,
eval_set=None, eval_names=None, eval_sample_weight=None,
eval_class_weight=None, eval_init_score=None, eval_metric=None,
early_stopping_rounds=None, verbose=True,
feature_name='auto', categorical_feature='auto', callbacks=None):
feature_name='auto', categorical_feature='auto',
callbacks=None, init_model=None):
"""Docstring is inherited from the LGBMModel."""
_LGBMAssertAllFinite(y)
_LGBMCheckClassificationTargets(y)
Expand Down Expand Up @@ -804,7 +812,7 @@ def fit(self, X, y,
early_stopping_rounds=early_stopping_rounds,
verbose=verbose, feature_name=feature_name,
categorical_feature=categorical_feature,
callbacks=callbacks)
callbacks=callbacks, init_model=init_model)
return self

fit.__doc__ = LGBMModel.fit.__doc__
Expand Down Expand Up @@ -896,7 +904,8 @@ def fit(self, X, y,
eval_set=None, eval_names=None, eval_sample_weight=None,
eval_init_score=None, eval_group=None, eval_metric=None,
eval_at=[1], early_stopping_rounds=None, verbose=True,
feature_name='auto', categorical_feature='auto', callbacks=None):
feature_name='auto', categorical_feature='auto',
callbacks=None, init_model=None):
"""Docstring is inherited from the LGBMModel."""
# check group data
if group is None:
Expand Down Expand Up @@ -924,7 +933,7 @@ def fit(self, X, y,
early_stopping_rounds=early_stopping_rounds,
verbose=verbose, feature_name=feature_name,
categorical_feature=categorical_feature,
callbacks=callbacks)
callbacks=callbacks, init_model=init_model)
return self

_base_doc = LGBMModel.fit.__doc__
Expand Down
13 changes: 13 additions & 0 deletions tests/python_package_test/test_sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -790,3 +790,16 @@ def test_class_weight(self):
for metric in gbm.evals_result_[eval_set]:
np.testing.assert_allclose(gbm.evals_result_[eval_set][metric],
gbm_str.evals_result_[eval_set][metric])

def test_continue_training_with_model(self):
    """Continuing training via init_model should add boosting rounds that improve the metric."""
    X, y = load_digits(3, True)
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)
    # Fit a small base model first, then fit a second model that continues from it.
    base_model = lgb.LGBMClassifier(n_estimators=5).fit(X_train, y_train,
                                                        eval_set=(X_test, y_test),
                                                        verbose=False)
    continued_model = lgb.LGBMClassifier(n_estimators=5).fit(X_train, y_train,
                                                             eval_set=(X_test, y_test),
                                                             verbose=False,
                                                             init_model=base_model)
    base_loss = base_model.evals_result_['valid_0']['multi_logloss']
    continued_loss = continued_model.evals_result_['valid_0']['multi_logloss']
    # Both runs record one eval value per boosting round (n_estimators=5 each).
    self.assertEqual(len(base_loss), len(continued_loss))
    self.assertEqual(len(base_loss), 5)
    # The continued model starts from the base model, so its final loss must be lower.
    self.assertLess(continued_loss[-1], base_loss[-1])

0 comments on commit f3afe98

Please sign in to comment.