
Commit 41d0645

Carl Hvarfner authored and facebook-github-bot committed
Added reset method to StoppingCriterion (#2927)
Summary:
Pull Request resolved: #2927

A custom StoppingCriterion currently does not work as intended when multiple restarts are used in fit_gpytorch_mll_(torch). Because the criterion preserves state between calls, data from the previous fitting attempt is retained, so the criterion is generally satisfied immediately on the new attempt, stopping the fit on its first iteration. StoppingCriterion is now a Protocol with methods `__call__` and `reset`, and `torch_minimize` resets the criterion before each run.

Reviewed By: Balandat

Differential Revision: D78343624

fbshipit-source-id: 6897a9c3f50f47fa7db4e9617e51438afd56d7b8
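A minimal sketch of the failure mode (standalone, not taken from the BoTorch sources): reusing a stateful ExpMAStoppingCriterion across fitting attempts leaves `iter` and `_prev_fvals` populated, so without a reset the next attempt can stop almost immediately.

    import torch
    from botorch.optim.stopping import ExpMAStoppingCriterion

    sc = ExpMAStoppingCriterion(maxiter=100, n_window=4)

    # First fitting attempt: feed a decreasing loss sequence
    # (stop early if the criterion fires).
    for loss in torch.linspace(1.0, 0.0, 50):
        if sc(fvals=loss.view(1)):
            break

    # State carried over from the first attempt.
    assert sc.iter > 0 and sc._prev_fvals is not None

    # Second attempt: torch_minimize now calls reset() so the run starts clean.
    sc.reset()
    assert sc.iter == 0 and sc._prev_fvals is None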
1 parent 223656a commit 41d0645

File tree

7 files changed: +53 −29 lines changed

botorch/generation/gen.py

Lines changed: 1 addition & 1 deletion
@@ -601,7 +601,7 @@ def assign_grad():
             return loss
 
         _optimizer.step(assign_grad)
-        stop = stopping_criterion.evaluate(fvals=loss.detach())
+        stop = stopping_criterion(fvals=loss.detach())
         if timeout_sec is not None:
             runtime = time.monotonic() - start_time
             if runtime > timeout_sec:

botorch/optim/core.py

Lines changed: 6 additions & 1 deletion
@@ -20,6 +20,7 @@
 import numpy.typing as npt
 
 from botorch.optim.closures import NdarrayOptimizationClosure
+from botorch.optim.stopping import StoppingCriterion
 from botorch.optim.utils.numpy_utils import get_bounds_as_ndarray
 from botorch.optim.utils.timeout import minimize_with_timeout
 from numpy import asarray, float64 as np_float64
@@ -153,7 +154,7 @@ def torch_minimize(
     scheduler: LRScheduler | Callable[[Optimizer], LRScheduler] | None = None,
     step_limit: int | None = None,
     timeout_sec: float | None = None,
-    stopping_criterion: Callable[[Tensor], bool] | None = None,
+    stopping_criterion: StoppingCriterion | None = None,
 ) -> OptimizationResult:
     r"""Generic torch.optim-based optimization routine.
 
@@ -190,6 +191,10 @@ def torch_minimize(
     if not (scheduler is None or isinstance(scheduler, LRScheduler)):
         scheduler = scheduler(optimizer)
 
+    if stopping_criterion is not None:
+        # Reset stopping criterion to ensure clean state for new optimization run
+        stopping_criterion.reset()
+
     _bounds = (
         {}
         if bounds is None
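With this change, the same criterion instance can be passed to back-to-back `torch_minimize` calls without manual resetting. A hedged usage sketch: the quadratic objective and parameter names are illustrative, and the closure is assumed to return a `(loss, grads)` tuple matching the `Callable[[], tuple[Tensor, Sequence[Tensor | None]]]` annotation used elsewhere in this commit.

    import torch
    from botorch.optim.core import torch_minimize
    from botorch.optim.stopping import ExpMAStoppingCriterion

    x = torch.zeros(2, requires_grad=True)

    def closure():
        # Toy objective; returns (loss, gradients) for torch_minimize to apply.
        loss = ((x - 1.0) ** 2).sum()
        grads = torch.autograd.grad(loss, [x])
        return loss, grads

    criterion = ExpMAStoppingCriterion(maxiter=50)

    for restart in range(2):
        # torch_minimize calls criterion.reset() internally, so the second
        # restart is not cut short by state left over from the first.
        result = torch_minimize(
            closure=closure,
            parameters={"x": x},
            stopping_criterion=criterion,
        )
        print(restart, result.step, result.status)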

botorch/optim/fit.py

Lines changed: 2 additions & 2 deletions
@@ -22,7 +22,7 @@
     scipy_minimize,
     torch_minimize,
 )
-from botorch.optim.stopping import ExpMAStoppingCriterion
+from botorch.optim.stopping import ExpMAStoppingCriterion, StoppingCriterion
 from botorch.optim.utils import get_parameters_and_bounds, TorchAttr
 from botorch.utils.types import DEFAULT
 from gpytorch.mlls.marginal_log_likelihood import MarginalLogLikelihood
@@ -118,7 +118,7 @@ def fit_gpytorch_mll_torch(
     closure: Callable[[], tuple[Tensor, Sequence[Tensor | None]]] | None = None,
     closure_kwargs: dict[str, Any] | None = None,
     step_limit: int | None = None,
-    stopping_criterion: Callable[[Tensor], bool] | None = DEFAULT,  # pyre-ignore [9]
+    stopping_criterion: StoppingCriterion | None = DEFAULT,
     optimizer: Optimizer | Callable[..., Optimizer] = Adam,
     scheduler: _LRScheduler | Callable[..., _LRScheduler] | None = None,
     callback: Callable[[dict[str, Tensor], OptimizationResult], None] | None = None,

botorch/optim/optimize.py

Lines changed: 2 additions & 2 deletions
@@ -826,7 +826,7 @@ def optimize_acqf_cyclic(
     if q > 1:
         cyclic_options = cyclic_options or {}
         stopping_criterion = ExpMAStoppingCriterion(**cyclic_options)
-        stop = stopping_criterion.evaluate(fvals=acq_vals)
+        stop = stopping_criterion(fvals=acq_vals)
         base_X_pending = acq_function.X_pending
         idxr = torch.ones(q, dtype=torch.bool, device=opt_inputs.bounds.device)
         while not stop:
@@ -847,7 +847,7 @@ def optimize_acqf_cyclic(
             candidates[i] = candidate_i
             acq_vals[i] = acq_val_i
             idxr[i] = 1
-            stop = stopping_criterion.evaluate(fvals=acq_vals)
+            stop = stopping_criterion(fvals=acq_vals)
         acq_function.set_X_pending(base_X_pending)
     return candidates, acq_vals

botorch/optim/stopping.py

Lines changed: 24 additions & 12 deletions
@@ -6,21 +6,20 @@
 
 from __future__ import annotations
 
-from abc import ABC, abstractmethod
+from typing import Protocol
 
 import torch
 from torch import Tensor
 
 
-class StoppingCriterion(ABC):
-    r"""Base class for evaluating optimization convergence.
+class StoppingCriterion(Protocol):
+    r"""Protocol for evaluating optimization convergence.
 
-    Stopping criteria are implemented as a objects rather than a function, so that they
+    Stopping criteria are implemented as objects rather than functions, so that they
     can keep track of past function values between optimization steps.
     """
 
-    @abstractmethod
-    def evaluate(self, fvals: Tensor) -> bool:
+    def __call__(self, fvals: Tensor) -> bool:
         r"""Evaluate the stopping criterion.
 
         Args:
@@ -30,15 +29,20 @@ def evaluate(self, fvals: Tensor) -> bool:
                 true for all elements.
 
         Returns:
-            Stopping indicator (if True, stop the optimziation).
+            Stopping indicator (if True, stop the optimization).
         """
-        pass  # pragma: no cover
+        ...  # pragma: no cover
 
-    def __call__(self, fvals: Tensor) -> bool:
-        return self.evaluate(fvals)
+    def reset(self) -> None:
+        r"""Reset the stopping criterion to its initial state.
+
+        This method should be called before starting a new optimization run
+        to ensure that any internal state from previous runs is cleared.
+        """
+        ...  # pragma: no cover
 
 
-class ExpMAStoppingCriterion(StoppingCriterion):
+class ExpMAStoppingCriterion:
     r"""Exponential moving average stopping criterion.
 
     Computes an exponentially weighted moving average over window length `n_window`
@@ -80,7 +84,7 @@ def __init__(
         self.weights = weights / weights.sum()
         self._prev_fvals = None
 
-    def evaluate(self, fvals: Tensor) -> bool:
+    def __call__(self, fvals: Tensor) -> bool:
        r"""Evaluate the stopping criterion.
 
         Args:
@@ -125,3 +129,11 @@ def evaluate(self, fvals: Tensor) -> bool:
             return True
 
         return False
+
+    def reset(self) -> None:
+        r"""Reset the stopping criterion to its initial state.
+
+        Resets the iteration counter and clears any stored function values.
+        """
+        self.iter = 0
+        self._prev_fvals = None
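Since `StoppingCriterion` is now a `typing.Protocol`, conformance is structural: `ExpMAStoppingCriterion` drops the base class yet still type-checks wherever a `StoppingCriterion` is expected, and so does any user-defined class with matching methods. A sketch of such a custom criterion (the patience logic is illustrative, not part of BoTorch):

    from __future__ import annotations

    from torch import Tensor

    class PatienceStoppingCriterion:
        # No inheritance required: defining __call__ and reset with the right
        # signatures satisfies the StoppingCriterion Protocol structurally.
        def __init__(self, patience: int = 5) -> None:
            self.patience = patience
            self.best: float | None = None
            self.num_stale = 0

        def __call__(self, fvals: Tensor) -> bool:
            # Stop after `patience` consecutive steps without improvement.
            fval = fvals.mean().item()
            if self.best is None or fval < self.best:
                self.best, self.num_stale = fval, 0
            else:
                self.num_stale += 1
            return self.num_stale >= self.patience

        def reset(self) -> None:
            # Clear carried-over state; torch_minimize calls this before each run.
            self.best = None
            self.num_stale = 0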

test/optim/test_core.py

Lines changed: 3 additions & 2 deletions
@@ -17,6 +17,7 @@
     scipy_minimize,
     torch_minimize,
 )
+from botorch.optim.stopping import ExpMAStoppingCriterion
 from botorch.utils.testing import BotorchTestCase
 from numpy import allclose
 from scipy.optimize import OptimizeResult
@@ -254,11 +255,11 @@ def _callback(parameters, result, out) -> None:
         self.assertEqual(result.step, len(step_results))
 
         # Test `stopping_criterion`
-        stopping_decisions = iter((False, False, True, False))
+        max3_stopping_criterion = ExpMAStoppingCriterion(maxiter=3, n_window=5)
         result = torch_minimize(
             closure=closure,
             parameters=closure.parameters,
-            stopping_criterion=lambda fval: next(stopping_decisions),
+            stopping_criterion=max3_stopping_criterion,
         )
         self.assertEqual(result.step, 3)
         self.assertEqual(result.status, OptimizationStatus.STOPPED)

test/optim/test_stopping.py

Lines changed: 15 additions & 9 deletions
@@ -7,15 +7,11 @@
 from __future__ import annotations
 
 import torch
-from botorch.optim.stopping import ExpMAStoppingCriterion, StoppingCriterion
+from botorch.optim.stopping import ExpMAStoppingCriterion
 from botorch.utils.testing import BotorchTestCase
 
 
 class TestStoppingCriterion(BotorchTestCase):
-    def test_abstract_raises(self):
-        with self.assertRaises(TypeError):
-            StoppingCriterion()
-
     def test_exponential_moving_average(self):
         for dtype in (torch.float, torch.double):
             tkwargs = {"device": self.device, "dtype": dtype}
@@ -25,8 +21,8 @@ def test_exponential_moving_average(self):
             self.assertEqual(sc.maxiter, 2)
             self.assertEqual(sc.n_window, 10)
             self.assertEqual(sc.rel_tol, 1e-5)
-            self.assertFalse(sc.evaluate(fvals=torch.ones(1, **tkwargs)))
-            self.assertTrue(sc.evaluate(fvals=torch.zeros(1, **tkwargs)))
+            self.assertFalse(sc(fvals=torch.ones(1, **tkwargs)))
+            self.assertTrue(sc(fvals=torch.zeros(1, **tkwargs)))
 
             # test convergence
             n_window = 4
@@ -43,7 +39,7 @@ def test_exponential_moving_average(self):
             if not minimize:
                 f_vals = -f_vals
             for i, fval in enumerate(f_vals):
-                if sc.evaluate(fval):
+                if sc(fval):
                     self.assertEqual(i, 10)
                     break
             # test multiple components
@@ -55,6 +51,16 @@ def test_exponential_moving_average(self):
                 df = -df
             f_vals = torch.stack([f_vals, f_vals + df], dim=-1)
             for i, fval in enumerate(f_vals):
-                if sc.evaluate(fval):
+                if sc(fval):
                     self.assertEqual(i, 10)
                     break
+
+            # Test reset functionality - verify state after use, reset, and reuse
+            self.assertGreater(sc.iter, 0)
+            self.assertIsNotNone(sc._prev_fvals)
+            sc.reset()
+            self.assertEqual(sc.iter, 0)
+            self.assertIsNone(sc._prev_fvals)
+            # Verify criterion works after reset
+            self.assertFalse(sc(f_vals[0]))
+            self.assertEqual(sc.iter, 1)
