Skip to content

Commit eb011fd

Browse files
ogrisellorentzenchrjeremiedbb
committed
Fix AttributeError use_fallback_lbfgs_solve for newton-cholesky when fitting with max_iter=0 (#26653)
Co-authored-by: Christian Lorentzen <lorentzen.ch@gmail.com> Co-authored-by: Jérémie du Boisberranger <34657725+jeremiedbb@users.noreply.github.com>
1 parent ee89eb8 commit eb011fd

File tree

3 files changed

+32
-0
lines changed

3 files changed

+32
-0
lines changed

doc/whats_new/v1.3.rst

+5
Original file line numberDiff line numberDiff line change
@@ -455,6 +455,11 @@ Changelog
455455
on linearly separable problems.
456456
:pr:`25214` by `Tom Dupre la Tour`_.
457457

458+
- |Fix| Fix a crash when calling `fit` on
459+
:class:`linear_model.LogisticRegression(solver="newton-cholesky", max_iter=0)`
460+
which failed to inspect the state of the model prior to the first parameter update.
461+
:pr:`26653` by :user:`Olivier Grisel <ogrisel>`.
462+
458463
- |API| Deprecates `n_iter` in favor of `max_iter` in
459464
:class:`linear_model.BayesianRidge` and :class:`linear_model.ARDRegression`.
460465
`n_iter` will be removed in scikit-learn 1.5. This change makes those

sklearn/linear_model/_glm/_newton_solver.py

+1
Original file line numberDiff line numberDiff line change
@@ -375,6 +375,7 @@ def solve(self, X, y, sample_weight):
375375

376376
self.iteration = 1
377377
self.converged = False
378+
self.use_fallback_lbfgs_solve = False
378379

379380
while self.iteration <= self.max_iter and not self.converged:
380381
if self.verbose:

sklearn/linear_model/tests/test_logistic.py

+26
Original file line numberDiff line numberDiff line change
@@ -2063,3 +2063,29 @@ def test_liblinear_not_stuck():
20632063
with warnings.catch_warnings():
20642064
warnings.simplefilter("error", ConvergenceWarning)
20652065
clf.fit(X_prep, y)
2066+
2067+
2068+
@pytest.mark.parametrize("solver", SOLVERS)
2069+
def test_zero_max_iter(solver):
2070+
# Make sure we can inspect the state of LogisticRegression right after
2071+
# initialization (before the first weight update).
2072+
X, y = load_iris(return_X_y=True)
2073+
y = y == 2
2074+
with ignore_warnings(category=ConvergenceWarning):
2075+
clf = LogisticRegression(solver=solver, max_iter=0).fit(X, y)
2076+
if solver not in ["saga", "sag"]:
2077+
# XXX: sag and saga have n_iter_ = [1]...
2078+
assert clf.n_iter_ == 0
2079+
2080+
if solver != "lbfgs":
2081+
# XXX: lbfgs has already started to update the coefficients...
2082+
assert_allclose(clf.coef_, np.zeros_like(clf.coef_))
2083+
assert_allclose(
2084+
clf.decision_function(X),
2085+
np.full(shape=X.shape[0], fill_value=clf.intercept_),
2086+
)
2087+
assert_allclose(
2088+
clf.predict_proba(X),
2089+
np.full(shape=(X.shape[0], 2), fill_value=0.5),
2090+
)
2091+
assert clf.score(X, y) < 0.7

0 commit comments

Comments
 (0)