Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[python-package] add scikit-learn-style API for early stopping #5808

Open
wants to merge 33 commits into
base: master
Choose a base branch
from
Open
Changes from 1 commit
Commits
Show all changes
33 commits
Select commit. Hold shift + click to select a range.
ad43e17
Enable Auto Early Stopping
ClaudioSalvatoreArcidiacono Sep 15, 2023
f05e5e0
Relax test conditions
ClaudioSalvatoreArcidiacono Sep 18, 2023
76f3c19
Merge branch 'master' into 3313-enable-auto-early-stopping
ClaudioSalvatoreArcidiacono Sep 20, 2023
457c7f6
Merge master
ClaudioSalvatoreArcidiacono Jan 16, 2024
0db1941
Revert "Merge master"
ClaudioSalvatoreArcidiacono Jan 16, 2024
10fac65
Merge remote-tracking branch 'lgbm/master' into 3313-enable-auto-early-stopping
ClaudioSalvatoreArcidiacono Jan 16, 2024
d10ca54
Add missing import
ClaudioSalvatoreArcidiacono Jan 16, 2024
3b8eb0a
Remove added extra new line
ClaudioSalvatoreArcidiacono Jan 17, 2024
e47acc0
Merge branch 'master' into 3313-enable-auto-early-stopping
ClaudioSalvatoreArcidiacono Jan 17, 2024
66701ac
Merge branch 'master' into 3313-enable-auto-early-stopping
ClaudioSalvatoreArcidiacono Jan 25, 2024
39d333e
Merge branch 'master' into 3313-enable-auto-early-stopping
ClaudioSalvatoreArcidiacono Feb 2, 2024
cad7eb6
Merge branch 'master' into 3313-enable-auto-early-stopping
ClaudioSalvatoreArcidiacono Feb 6, 2024
1234ccf
Merge master
ClaudioSalvatoreArcidiacono Nov 28, 2024
d54c96a
Improve documentation, check default behavior of early stopping
ClaudioSalvatoreArcidiacono Nov 28, 2024
9c1c8b4
Solve python 3.8 compatibility issue
ClaudioSalvatoreArcidiacono Nov 28, 2024
724c7fe
Remove default to auto
ClaudioSalvatoreArcidiacono Nov 29, 2024
c957fce
Revert changes in fit top part
ClaudioSalvatoreArcidiacono Nov 29, 2024
2d7da78
Make interface as similar as possible to sklearn
ClaudioSalvatoreArcidiacono Nov 29, 2024
069a84e
Add parameters to dask interface
ClaudioSalvatoreArcidiacono Nov 29, 2024
c430ec1
Improve documentation
ClaudioSalvatoreArcidiacono Nov 29, 2024
416323a
Linting
ClaudioSalvatoreArcidiacono Nov 29, 2024
73562ff
Check for exact value equal true for early stopping
ClaudioSalvatoreArcidiacono Nov 29, 2024
38edc42
Merge branch 'master' into 3313-enable-auto-early-stopping
jameslamb Dec 15, 2024
9a32376
Switch if/else conditions order in fit
ClaudioSalvatoreArcidiacono Dec 18, 2024
f33ebd3
Merge remote-tracking branch 'origin/master' into 3313-enable-auto-early-stopping
ClaudioSalvatoreArcidiacono Dec 18, 2024
a61726f
fix issues in engine.py
ClaudioSalvatoreArcidiacono Dec 18, 2024
44316d7
make new early stopping parameters keyword-only
ClaudioSalvatoreArcidiacono Dec 18, 2024
4cbfc84
Remove n_iter_no_change parameter
ClaudioSalvatoreArcidiacono Dec 18, 2024
93acf6a
Address comments in tests
ClaudioSalvatoreArcidiacono Dec 18, 2024
2b049c9
Improve tests
ClaudioSalvatoreArcidiacono Dec 18, 2024
61371cb
Add tests to check for validation fraction
ClaudioSalvatoreArcidiacono Dec 18, 2024
65c4e2f
Remove validation_fraction=None option
ClaudioSalvatoreArcidiacono Dec 18, 2024
0a8e843
Remove validation_fraction=None option also in dask
ClaudioSalvatoreArcidiacono Dec 18, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Switch if/else conditions order in fit
  • Loading branch information
ClaudioSalvatoreArcidiacono committed Dec 18, 2024
commit 9a32376856b85d8532923bb074455f41cf9e1157
100 changes: 50 additions & 50 deletions python-package/lightgbm/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -981,7 +981,56 @@ def fit(
params=params,
)

if self.early_stopping is True and eval_set is None:
valid_sets: List[Dataset] = []
if eval_set is not None:
if isinstance(eval_set, tuple):
eval_set = [eval_set]
for i, valid_data in enumerate(eval_set):
# reduce cost for prediction training data
if valid_data[0] is X and valid_data[1] is y:
valid_set = train_set
else:
valid_weight = _extract_evaluation_meta_data(
collection=eval_sample_weight,
name="eval_sample_weight",
i=i,
)
valid_class_weight = _extract_evaluation_meta_data(
collection=eval_class_weight,
name="eval_class_weight",
i=i,
)
if valid_class_weight is not None:
if isinstance(valid_class_weight, dict) and self._class_map is not None:
valid_class_weight = {self._class_map[k]: v for k, v in valid_class_weight.items()}
valid_class_sample_weight = _LGBMComputeSampleWeight(valid_class_weight, valid_data[1])
if valid_weight is None or len(valid_weight) == 0:
valid_weight = valid_class_sample_weight
else:
valid_weight = np.multiply(valid_weight, valid_class_sample_weight)
valid_init_score = _extract_evaluation_meta_data(
collection=eval_init_score,
name="eval_init_score",
i=i,
)
valid_group = _extract_evaluation_meta_data(
collection=eval_group,
name="eval_group",
i=i,
)
valid_set = Dataset(
data=valid_data[0],
label=valid_data[1],
weight=valid_weight,
group=valid_group,
init_score=valid_init_score,
categorical_feature="auto",
params=params,
)

valid_sets.append(valid_set)

elif self.early_stopping is True:
if self.validation_fraction is not None:
n_splits = max(int(np.ceil(1 / self.validation_fraction)), 2)
stratified = isinstance(self, LGBMClassifier)
Expand All @@ -1001,55 +1050,6 @@ def fit(
valid_set = train_set
valid_set = valid_set.construct()
valid_sets = [valid_set]
else:
valid_sets: List[Dataset] = []
if eval_set is not None:
if isinstance(eval_set, tuple):
eval_set = [eval_set]
for i, valid_data in enumerate(eval_set):
# reduce cost for prediction training data
if valid_data[0] is X and valid_data[1] is y:
valid_set = train_set
else:
valid_weight = _extract_evaluation_meta_data(
collection=eval_sample_weight,
name="eval_sample_weight",
i=i,
)
valid_class_weight = _extract_evaluation_meta_data(
collection=eval_class_weight,
name="eval_class_weight",
i=i,
)
if valid_class_weight is not None:
if isinstance(valid_class_weight, dict) and self._class_map is not None:
valid_class_weight = {self._class_map[k]: v for k, v in valid_class_weight.items()}
valid_class_sample_weight = _LGBMComputeSampleWeight(valid_class_weight, valid_data[1])
if valid_weight is None or len(valid_weight) == 0:
valid_weight = valid_class_sample_weight
else:
valid_weight = np.multiply(valid_weight, valid_class_sample_weight)
valid_init_score = _extract_evaluation_meta_data(
collection=eval_init_score,
name="eval_init_score",
i=i,
)
valid_group = _extract_evaluation_meta_data(
collection=eval_group,
name="eval_group",
i=i,
)
valid_set = Dataset(
data=valid_data[0],
label=valid_data[1],
weight=valid_weight,
group=valid_group,
init_score=valid_init_score,
categorical_feature="auto",
params=params,
)

valid_sets.append(valid_set)

if isinstance(init_model, LGBMModel):
init_model = init_model.booster_
Expand Down
Loading