Skip to content

Commit 81f2493

Browse files
committed
add ConvergenceWarning in do_line_search
1 parent 83c4628 commit 81f2493

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

celer/PN_logreg.pyx

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -400,11 +400,12 @@ cpdef void do_line_search(
400400
floating[::1, :] X, floating[:] X_data,
401401
int[:] X_indices, int[:] X_indptr, int MAX_BACKTRACK_ITR,
402402
floating[:] y, floating[:] exp_Xw, floating[:] low_exp_Xw,
403-
floating[:] aux, int[:] is_positive_label) nogil:
403+
floating[:] aux, int[:] is_positive_label):
404404

405405
cdef int i, ind, backtrack_itr
406406
cdef floating deriv
407407
cdef floating step_size = 1.
408+
cdef floating atol = 1e-7
408409

409410
cdef int n_samples = y.shape[0]
410411
fcopy(&n_samples, &exp_Xw[0], &inc, &low_exp_Xw[0], &inc)
@@ -417,15 +418,18 @@ cpdef void do_line_search(
417418
deriv = compute_derivative(
418419
w, WS, delta_w, X_delta_w, alpha, aux, step_size, y)
419420

420-
if deriv < 1e-7:
421+
if deriv < atol:
421422
break
422423
else:
423424
step_size = step_size / 2.
424425
for i in range(n_samples):
425426
exp_Xw[i] = sqrt(exp_Xw[i] * low_exp_Xw[i])
426427
else:
427-
pass
428-
# TODO what do we do in this case?
428+
warnings.warn(
429+
'Line search failed to converge '
430+
f'deriv {deriv:.2e}, atol {atol:.2e}',
431+
ConvergenceWarning
432+
)
429433

430434
# a suitable step size is found, perform step:
431435
for ind in range(WS.shape[0]):

0 commit comments

Comments
 (0)