Skip to content

Commit 61cb668

Browse files
authored
Merge c728104 into ad38736
2 parents ad38736 + c728104 commit 61cb668

File tree

3 files changed

+131
-45
lines changed

3 files changed

+131
-45
lines changed

botorch/optim/initializers.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
import warnings
1818
from math import ceil
19-
from typing import Dict, List, Optional, Tuple, Union
19+
from typing import Callable, Dict, List, Optional, Tuple, Union
2020

2121
import torch
2222
from botorch import settings
@@ -46,6 +46,22 @@
4646
from torch.distributions import Normal
4747
from torch.quasirandom import SobolEngine
4848

49+
TGenInitialConditions = Callable[
50+
[
51+
# reasoning behind this annotation: contravariance
52+
qKnowledgeGradient,
53+
Tensor,
54+
int,
55+
int,
56+
int,
57+
Optional[Dict[int, float]],
58+
Optional[Dict[str, Union[bool, float, int]]],
59+
Optional[List[Tuple[Tensor, Tensor, float]]],
60+
Optional[List[Tuple[Tensor, Tensor, float]]],
61+
],
62+
Optional[Tensor],
63+
]
64+
4965

5066
def gen_batch_initial_conditions(
5167
acq_function: AcquisitionFunction,

botorch/optim/optimize.py

Lines changed: 103 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from botorch.optim.initializers import (
3131
gen_batch_initial_conditions,
3232
gen_one_shot_kg_initial_conditions,
33+
TGenInitialConditions,
3334
)
3435
from botorch.optim.stopping import ExpMAStoppingCriterion
3536
from botorch.optim.utils import _filter_kwargs
@@ -53,7 +54,7 @@
5354
}
5455

5556

56-
@dataclasses.dataclass
57+
@dataclasses.dataclass(frozen=True)
5758
class OptimizeAcqfInputs:
5859
"""
5960
Container for inputs to `optimize_acqf`.
@@ -76,10 +77,19 @@ class OptimizeAcqfInputs:
7677
return_best_only: bool
7778
gen_candidates: TGenCandidates
7879
sequential: bool
79-
kwargs: Dict[str, Any]
80-
ic_generator: Callable = dataclasses.field(init=False)
80+
ic_generator: Optional[TGenInitialConditions] = None
81+
timeout_sec: Optional[float] = None
82+
return_full_tree: bool = False
83+
ic_gen_kwargs: Dict = dataclasses.field(default_factory=dict)
84+
85+
@property
86+
def full_tree(self) -> bool:
87+
return (
88+
isinstance(self.acq_function, OneShotAcquisitionFunction)
89+
and not self.return_full_tree
90+
)
8191

82-
def _validate(self) -> None:
92+
def __post_init__(self) -> None:
8393
if self.inequality_constraints is None and not (
8494
self.bounds.ndim == 2 and self.bounds.shape[0] == 2
8595
):
@@ -114,7 +124,7 @@ def _validate(self) -> None:
114124
f"shape is {batch_initial_conditions_shape}."
115125
)
116126

117-
elif "ic_generator" not in self.kwargs.keys():
127+
elif not self.ic_generator:
118128
if self.nonlinear_inequality_constraints:
119129
raise RuntimeError(
120130
"`ic_generator` must be given if "
@@ -137,14 +147,31 @@ def _validate(self) -> None:
137147
"acquisition functions. Must have `sequential=False`."
138148
)
139149

140-
def __post_init__(self) -> None:
141-
self._validate()
142-
if "ic_generator" in self.kwargs.keys():
143-
self.ic_generator = self.kwargs.pop("ic_generator")
150+
@property
151+
def ic_gen(self) -> TGenInitialConditions:
152+
if self.ic_generator:
153+
return self.ic_generator
144154
elif isinstance(self.acq_function, qKnowledgeGradient):
145-
self.ic_generator = gen_one_shot_kg_initial_conditions
146-
else:
147-
self.ic_generator = gen_batch_initial_conditions
155+
return gen_one_shot_kg_initial_conditions
156+
return gen_batch_initial_conditions
157+
158+
159+
def _raise_deprecation_warning_if_kwargs(fn_name: str, kwargs: Dict[str, Any]) -> None:
160+
"""
161+
Raise a warning if kwargs are provided.
162+
163+
Some functions used to support **kwargs. The applicable parameters have now been
164+
refactored to be named arguments, so no warning will be raised for users passing
165+
the expected arguments. However, if a user had been passing an inapplicable
166+
keyword argument, this will now raise a warning whereas in the past it did
167+
nothing.
168+
"""
169+
if len(kwargs) > 0:
170+
warnings.warn(
171+
f"`{fn_name}` does not support arguments {list(kwargs.keys())}. In "
172+
"the future, this will become an error.",
173+
DeprecationWarning,
174+
)
148175

149176

150177
def _optimize_acqf_all_features_fixed(
@@ -170,34 +197,32 @@ def _optimize_acqf_all_features_fixed(
170197

171198

172199
def _optimize_acqf_sequential_q(
173-
opt_inputs: OptimizeAcqfInputs,
174-
timeout_sec: Optional[float],
175-
start_time: float,
200+
opt_inputs: OptimizeAcqfInputs, timeout_sec: Optional[float], start_time: float
176201
) -> Tuple[Tensor, Tensor]:
177202
"""
178203
Helper function for `optimize_acqf` when sequential=True and q > 1.
179204
"""
180-
kwargs = opt_inputs.kwargs or {}
181205
if timeout_sec is not None:
182206
# When using sequential optimization, we allocate the total timeout
183207
# evenly across the individual acquisition optimizations.
184208
timeout_sec = (timeout_sec - start_time) / opt_inputs.q
185-
kwargs["timeout_sec"] = timeout_sec
186209

187210
candidate_list, acq_value_list = [], []
188211
base_X_pending = opt_inputs.acq_function.X_pending
189212

213+
new_inputs = dataclasses.replace(
214+
opt_inputs,
215+
q=1,
216+
batch_initial_conditions=None,
217+
return_best_only=True,
218+
sequential=False,
219+
timeout_sec=timeout_sec,
220+
)
190221
for i in range(opt_inputs.q):
191-
kwargs["ic_generator"] = opt_inputs.ic_generator
192-
new_inputs = dataclasses.replace(
193-
opt_inputs,
194-
q=1,
195-
batch_initial_conditions=None,
196-
return_best_only=True,
197-
sequential=False,
198-
kwargs=kwargs,
222+
223+
candidate, acq_value = _optimize_acqf_batch(
224+
new_inputs, start_time=start_time, timeout_sec=timeout_sec
199225
)
200-
candidate, acq_value = _optimize_acqf(new_inputs)
201226

202227
candidate_list.append(candidate)
203228
acq_value_list.append(acq_value)
@@ -217,17 +242,13 @@ def _optimize_acqf_batch(
217242
) -> Tuple[Tensor, Tensor]:
218243
options = opt_inputs.options or {}
219244

220-
kwargs = opt_inputs.kwargs
221-
full_tree = isinstance(
222-
opt_inputs.acq_function, OneShotAcquisitionFunction
223-
) and not kwargs.pop("return_full_tree", False)
224-
225245
initial_conditions_provided = opt_inputs.batch_initial_conditions is not None
226246

227247
if initial_conditions_provided:
228248
batch_initial_conditions = opt_inputs.batch_initial_conditions
229249
else:
230-
batch_initial_conditions = opt_inputs.ic_generator(
250+
# pyre-ignore[28]: Unexpected keyword argument `acq_function` to anonymous call.
251+
batch_initial_conditions = opt_inputs.ic_gen(
231252
acq_function=opt_inputs.acq_function,
232253
bounds=opt_inputs.bounds,
233254
q=opt_inputs.q,
@@ -237,7 +258,7 @@ def _optimize_acqf_batch(
237258
options=options,
238259
inequality_constraints=opt_inputs.inequality_constraints,
239260
equality_constraints=opt_inputs.equality_constraints,
240-
**kwargs,
261+
**opt_inputs.ic_gen_kwargs,
241262
)
242263

243264
batch_limit: int = options.get(
@@ -330,7 +351,7 @@ def _optimize_batch_candidates(
330351
warnings.warn(first_warn_msg, RuntimeWarning)
331352

332353
if not initial_conditions_provided:
333-
batch_initial_conditions = opt_inputs.ic_generator(
354+
batch_initial_conditions = opt_inputs.ic_gen(
334355
acq_function=opt_inputs.acq_function,
335356
bounds=opt_inputs.bounds,
336357
q=opt_inputs.q,
@@ -340,7 +361,7 @@ def _optimize_batch_candidates(
340361
options=options,
341362
inequality_constraints=opt_inputs.inequality_constraints,
342363
equality_constraints=opt_inputs.equality_constraints,
343-
**kwargs,
364+
**opt_inputs.ic_gen_kwargs,
344365
)
345366

346367
batch_candidates, batch_acq_values, ws = _optimize_batch_candidates(
@@ -365,7 +386,7 @@ def _optimize_batch_candidates(
365386
batch_candidates = batch_candidates[best]
366387
batch_acq_values = batch_acq_values[best]
367388

368-
if full_tree:
389+
if opt_inputs.full_tree:
369390
batch_candidates = opt_inputs.acq_function.extract_candidates(
370391
X_full=batch_candidates
371392
)
@@ -389,7 +410,11 @@ def optimize_acqf(
389410
return_best_only: bool = True,
390411
gen_candidates: Optional[TGenCandidates] = None,
391412
sequential: bool = False,
392-
**kwargs: Any,
413+
*,
414+
ic_generator: Optional[TGenInitialConditions] = None,
415+
timeout_sec: Optional[float] = None,
416+
return_full_tree: bool = False,
417+
**ic_gen_kwargs: Any,
393418
) -> Tuple[Tensor, Tensor]:
394419
r"""Generate a set of candidates via multi-start optimization.
395420
@@ -435,7 +460,15 @@ def optimize_acqf(
435460
for method-specific inputs. Default: `gen_candidates_scipy`
436461
sequential: If False, uses joint optimization, otherwise uses sequential
437462
optimization.
438-
kwargs: Additonal keyword arguments.
463+
ic_generator: Function for generating initial conditions. Not needed when
464+
`batch_initial_conditions` are provided. Defaults to
465+
`gen_one_shot_kg_initial_conditions` for `qKnowledgeGradient` acquisition
466+
functions and `gen_batch_initial_conditions` otherwise. Must be specified
467+
for nonlinear inequality constraints.
468+
timeout_sec: Max amount of time optimization can run for.
469+
return_full_tree:
470+
ic_gen_kwargs: Additional keyword arguments passed to function specified by
471+
`ic_generator`
439472
440473
Returns:
441474
A two-element tuple containing
@@ -481,7 +514,10 @@ def optimize_acqf(
481514
return_best_only=return_best_only,
482515
gen_candidates=gen_candidates,
483516
sequential=sequential,
484-
kwargs=kwargs,
517+
ic_generator=ic_generator,
518+
timeout_sec=timeout_sec,
519+
return_full_tree=return_full_tree,
520+
ic_gen_kwargs=ic_gen_kwargs,
485521
)
486522
return _optimize_acqf(opt_acqf_inputs)
487523

@@ -501,8 +537,7 @@ def _optimize_acqf(opt_inputs: OptimizeAcqfInputs) -> Tuple[Tensor, Tensor]:
501537
)
502538

503539
start_time: float = time.monotonic()
504-
kwargs = opt_inputs.kwargs
505-
timeout_sec = kwargs.pop("timeout_sec", None)
540+
timeout_sec = opt_inputs.timeout_sec
506541

507542
# Perform sequential optimization via successive conditioning on pending points
508543
if opt_inputs.sequential and opt_inputs.q > 1:
@@ -531,7 +566,11 @@ def optimize_acqf_cyclic(
531566
post_processing_func: Optional[Callable[[Tensor], Tensor]] = None,
532567
batch_initial_conditions: Optional[Tensor] = None,
533568
cyclic_options: Optional[Dict[str, Union[bool, float, int, str]]] = None,
534-
**kwargs,
569+
*,
570+
ic_generator: Optional[TGenInitialConditions] = None,
571+
timeout_sec: Optional[float] = None,
572+
return_full_tree: bool = False,
573+
**ic_gen_kwargs: Any,
535574
) -> Tuple[Tensor, Tensor]:
536575
r"""Generate a set of `q` candidates via cyclic optimization.
537576
@@ -561,6 +600,15 @@ def optimize_acqf_cyclic(
561600
If no initial conditions are provided, the default initialization will
562601
be used.
563602
cyclic_options: Options for stopping criterion for outer cyclic optimization.
603+
ic_generator: Function for generating initial conditions. Not needed when
604+
`batch_initial_conditions` are provided. Defaults to
605+
`gen_one_shot_kg_initial_conditions` for `qKnowledgeGradient` acquisition
606+
functions and `gen_batch_initial_conditions` otherwise. Must be specified
607+
for nonlinear inequality constraints.
608+
timeout_sec: Max amount of time optimization can run for.
609+
return_full_tree:
610+
ic_gen_kwargs: Additional keyword arguments passed to function specified by
611+
`ic_generator`
564612
565613
Returns:
566614
A two-element tuple containing
@@ -596,7 +644,10 @@ def optimize_acqf_cyclic(
596644
return_best_only=True,
597645
gen_candidates=gen_candidates_scipy,
598646
sequential=True,
599-
kwargs=kwargs,
647+
ic_generator=ic_generator,
648+
timeout_sec=timeout_sec,
649+
return_full_tree=return_full_tree,
650+
ic_gen_kwargs=ic_gen_kwargs,
600651
)
601652

602653
# for the first cycle, optimize the q candidates sequentially
@@ -778,6 +829,8 @@ def optimize_acqf_mixed(
778829
transformations).
779830
batch_initial_conditions: A tensor to specify the initial conditions. Set
780831
this if you do not want to use default initialization strategy.
832+
kwargs: kwargs do nothing. This is provided so that the same arguments can
833+
be passed to different acquisition functions without raising an error.
781834
782835
Returns:
783836
A two-element tuple containing
@@ -795,6 +848,7 @@ def optimize_acqf_mixed(
795848
"are currently not supported when `q > 1`. This is needed to "
796849
"compute the joint acquisition value."
797850
)
851+
_raise_deprecation_warning_if_kwargs("optimize_acqf_mixed", kwargs)
798852

799853
if q == 1:
800854
ff_candidate_list, ff_acq_value_list = [], []
@@ -881,6 +935,8 @@ def optimize_acqf_discrete(
881935
a large training set.
882936
unique: If True return unique choices, o/w choices may be repeated
883937
(only relevant if `q > 1`).
938+
kwargs: kwargs do nothing. This is provided so that the same arguments can
939+
be passed to different acquisition functions without raising an error.
884940
885941
Returns:
886942
A three-element tuple containing
@@ -895,6 +951,7 @@ def optimize_acqf_discrete(
895951
)
896952
if choices.numel() == 0:
897953
raise InputDataError("`choices` must be non-emtpy.")
954+
_raise_deprecation_warning_if_kwargs("optimize_acqf_discrete", kwargs)
898955
choices_batched = choices.unsqueeze(-2)
899956
if q > 1:
900957
candidate_list, acq_value_list = [], []
@@ -1045,13 +1102,16 @@ def optimize_acqf_discrete_local_search(
10451102
a large training set.
10461103
unique: If True return unique choices, o/w choices may be repeated
10471104
(only relevant if `q > 1`).
1105+
kwargs: kwargs do nothing. This is provided so that the same arguments can
1106+
be passed to different acquisition functions without raising an error.
10481107
10491108
Returns:
10501109
A two-element tuple containing
10511110
10521111
- a `q x d`-dim tensor of generated candidates.
10531112
- an associated acquisition value.
10541113
"""
1114+
_raise_deprecation_warning_if_kwargs("optimize_acqf_discrete_local_search", kwargs)
10551115
candidate_list = []
10561116
base_X_pending = acq_function.X_pending if q > 1 else None
10571117
base_X_avoid = X_avoid

test/optim/test_optimize.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1394,7 +1394,6 @@ def test_optimize_acqf_discrete(self):
13941394

13951395
mock_acq_function = SquaredAcquisitionFunction()
13961396
mock_acq_function.set_X_pending(None)
1397-
13981397
# ensure proper raising of errors if no choices
13991398
with self.assertRaisesRegex(InputDataError, "`choices` must be non-emtpy."):
14001399
optimize_acqf_discrete(
@@ -1404,6 +1403,17 @@ def test_optimize_acqf_discrete(self):
14041403
)
14051404

14061405
choices = torch.rand(5, 2, **tkwargs)
1406+
1407+
# warning for unsupported keyword arguments
1408+
with self.assertWarnsRegex(
1409+
DeprecationWarning,
1410+
r"`optimize_acqf_discrete` does not support arguments "
1411+
r"\['num_restarts'\]. In the future, this will become an error.",
1412+
):
1413+
optimize_acqf_discrete(
1414+
acq_function=mock_acq_function, q=q, choices=choices, num_restarts=8
1415+
)
1416+
14071417
exp_acq_vals = mock_acq_function(choices)
14081418

14091419
# test unique

0 commit comments

Comments
 (0)