30
30
from botorch .optim .initializers import (
31
31
gen_batch_initial_conditions ,
32
32
gen_one_shot_kg_initial_conditions ,
33
+ TGenInitialConditions ,
33
34
)
34
35
from botorch .optim .stopping import ExpMAStoppingCriterion
35
36
from botorch .optim .utils import _filter_kwargs
53
54
}
54
55
55
56
56
- @dataclasses .dataclass
57
+ @dataclasses .dataclass ( frozen = True )
57
58
class OptimizeAcqfInputs :
58
59
"""
59
60
Container for inputs to `optimize_acqf`.
@@ -76,10 +77,19 @@ class OptimizeAcqfInputs:
76
77
return_best_only : bool
77
78
gen_candidates : TGenCandidates
78
79
sequential : bool
79
- kwargs : Dict [str , Any ]
80
- ic_generator : Callable = dataclasses .field (init = False )
80
+ ic_generator : Optional [TGenInitialConditions ] = None
81
+ timeout_sec : Optional [float ] = None
82
+ return_full_tree : bool = False
83
+ ic_gen_kwargs : Dict = dataclasses .field (default_factory = dict )
84
+
85
+ @property
86
+ def full_tree (self ) -> bool :
87
+ return (
88
+ isinstance (self .acq_function , OneShotAcquisitionFunction )
89
+ and not self .return_full_tree
90
+ )
81
91
82
- def _validate (self ) -> None :
92
+ def __post_init__ (self ) -> None :
83
93
if self .inequality_constraints is None and not (
84
94
self .bounds .ndim == 2 and self .bounds .shape [0 ] == 2
85
95
):
@@ -114,7 +124,7 @@ def _validate(self) -> None:
114
124
f"shape is { batch_initial_conditions_shape } ."
115
125
)
116
126
117
- elif "ic_generator" not in self .kwargs . keys () :
127
+ elif not self .ic_generator :
118
128
if self .nonlinear_inequality_constraints :
119
129
raise RuntimeError (
120
130
"`ic_generator` must be given if "
@@ -137,14 +147,31 @@ def _validate(self) -> None:
137
147
"acquisition functions. Must have `sequential=False`."
138
148
)
139
149
140
- def __post_init__ ( self ) -> None :
141
- self . _validate ()
142
- if "ic_generator" in self .kwargs . keys () :
143
- self . ic_generator = self .kwargs . pop ( " ic_generator" )
150
+ @ property
151
+ def ic_gen ( self ) -> TGenInitialConditions :
152
+ if self .ic_generator :
153
+ return self .ic_generator
144
154
elif isinstance (self .acq_function , qKnowledgeGradient ):
145
- self .ic_generator = gen_one_shot_kg_initial_conditions
146
- else :
147
- self .ic_generator = gen_batch_initial_conditions
155
+ return gen_one_shot_kg_initial_conditions
156
+ return gen_batch_initial_conditions
157
+
158
+
159
+ def _raise_deprecation_warning_if_kwargs (fn_name : str , kwargs : Dict [str , Any ]) -> None :
160
+ """
161
+ Raise a warning if kwargs are provided.
162
+
163
+ Some functions used to support **kwargs. The applicable parameters have now been
164
+ refactored to be named arguments, so no warning will be raised for users passing
165
+ the expected arguments. However, if a user had been passing an inapplicable
166
+ keyword argument, this will now raise a warning whereas in the past it did
167
+ nothing.
168
+ """
169
+ if len (kwargs ) > 0 :
170
+ warnings .warn (
171
+ f"`{ fn_name } ` does not support arguments { list (kwargs .keys ())} . In "
172
+ "the future, this will become an error." ,
173
+ DeprecationWarning ,
174
+ )
148
175
149
176
150
177
def _optimize_acqf_all_features_fixed (
@@ -170,34 +197,32 @@ def _optimize_acqf_all_features_fixed(
170
197
171
198
172
199
def _optimize_acqf_sequential_q (
173
- opt_inputs : OptimizeAcqfInputs ,
174
- timeout_sec : Optional [float ],
175
- start_time : float ,
200
+ opt_inputs : OptimizeAcqfInputs , timeout_sec : Optional [float ], start_time : float
176
201
) -> Tuple [Tensor , Tensor ]:
177
202
"""
178
203
Helper function for `optimize_acqf` when sequential=True and q > 1.
179
204
"""
180
- kwargs = opt_inputs .kwargs or {}
181
205
if timeout_sec is not None :
182
206
# When using sequential optimization, we allocate the total timeout
183
207
# evenly across the individual acquisition optimizations.
184
208
timeout_sec = (timeout_sec - start_time ) / opt_inputs .q
185
- kwargs ["timeout_sec" ] = timeout_sec
186
209
187
210
candidate_list , acq_value_list = [], []
188
211
base_X_pending = opt_inputs .acq_function .X_pending
189
212
213
+ new_inputs = dataclasses .replace (
214
+ opt_inputs ,
215
+ q = 1 ,
216
+ batch_initial_conditions = None ,
217
+ return_best_only = True ,
218
+ sequential = False ,
219
+ timeout_sec = timeout_sec ,
220
+ )
190
221
for i in range (opt_inputs .q ):
191
- kwargs ["ic_generator" ] = opt_inputs .ic_generator
192
- new_inputs = dataclasses .replace (
193
- opt_inputs ,
194
- q = 1 ,
195
- batch_initial_conditions = None ,
196
- return_best_only = True ,
197
- sequential = False ,
198
- kwargs = kwargs ,
222
+
223
+ candidate , acq_value = _optimize_acqf_batch (
224
+ new_inputs , start_time = start_time , timeout_sec = timeout_sec
199
225
)
200
- candidate , acq_value = _optimize_acqf (new_inputs )
201
226
202
227
candidate_list .append (candidate )
203
228
acq_value_list .append (acq_value )
@@ -217,17 +242,13 @@ def _optimize_acqf_batch(
217
242
) -> Tuple [Tensor , Tensor ]:
218
243
options = opt_inputs .options or {}
219
244
220
- kwargs = opt_inputs .kwargs
221
- full_tree = isinstance (
222
- opt_inputs .acq_function , OneShotAcquisitionFunction
223
- ) and not kwargs .pop ("return_full_tree" , False )
224
-
225
245
initial_conditions_provided = opt_inputs .batch_initial_conditions is not None
226
246
227
247
if initial_conditions_provided :
228
248
batch_initial_conditions = opt_inputs .batch_initial_conditions
229
249
else :
230
- batch_initial_conditions = opt_inputs .ic_generator (
250
+ # pyre-ignore[28]: Unexpected keyword argument `acq_function` to anonymous call.
251
+ batch_initial_conditions = opt_inputs .ic_gen (
231
252
acq_function = opt_inputs .acq_function ,
232
253
bounds = opt_inputs .bounds ,
233
254
q = opt_inputs .q ,
@@ -237,7 +258,7 @@ def _optimize_acqf_batch(
237
258
options = options ,
238
259
inequality_constraints = opt_inputs .inequality_constraints ,
239
260
equality_constraints = opt_inputs .equality_constraints ,
240
- ** kwargs ,
261
+ ** opt_inputs . ic_gen_kwargs ,
241
262
)
242
263
243
264
batch_limit : int = options .get (
@@ -330,7 +351,7 @@ def _optimize_batch_candidates(
330
351
warnings .warn (first_warn_msg , RuntimeWarning )
331
352
332
353
if not initial_conditions_provided :
333
- batch_initial_conditions = opt_inputs .ic_generator (
354
+ batch_initial_conditions = opt_inputs .ic_gen (
334
355
acq_function = opt_inputs .acq_function ,
335
356
bounds = opt_inputs .bounds ,
336
357
q = opt_inputs .q ,
@@ -340,7 +361,7 @@ def _optimize_batch_candidates(
340
361
options = options ,
341
362
inequality_constraints = opt_inputs .inequality_constraints ,
342
363
equality_constraints = opt_inputs .equality_constraints ,
343
- ** kwargs ,
364
+ ** opt_inputs . ic_gen_kwargs ,
344
365
)
345
366
346
367
batch_candidates , batch_acq_values , ws = _optimize_batch_candidates (
@@ -365,7 +386,7 @@ def _optimize_batch_candidates(
365
386
batch_candidates = batch_candidates [best ]
366
387
batch_acq_values = batch_acq_values [best ]
367
388
368
- if full_tree :
389
+ if opt_inputs . full_tree :
369
390
batch_candidates = opt_inputs .acq_function .extract_candidates (
370
391
X_full = batch_candidates
371
392
)
@@ -389,7 +410,11 @@ def optimize_acqf(
389
410
return_best_only : bool = True ,
390
411
gen_candidates : Optional [TGenCandidates ] = None ,
391
412
sequential : bool = False ,
392
- ** kwargs : Any ,
413
+ * ,
414
+ ic_generator : Optional [TGenInitialConditions ] = None ,
415
+ timeout_sec : Optional [float ] = None ,
416
+ return_full_tree : bool = False ,
417
+ ** ic_gen_kwargs : Any ,
393
418
) -> Tuple [Tensor , Tensor ]:
394
419
r"""Generate a set of candidates via multi-start optimization.
395
420
@@ -435,7 +460,15 @@ def optimize_acqf(
435
460
for method-specific inputs. Default: `gen_candidates_scipy`
436
461
sequential: If False, uses joint optimization, otherwise uses sequential
437
462
optimization.
438
- kwargs: Additonal keyword arguments.
463
+ ic_generator: Function for generating initial conditions. Not needed when
464
+ `batch_initial_conditions` are provided. Defaults to
465
+ `gen_one_shot_kg_initial_conditions` for `qKnowledgeGradient` acquisition
466
+ functions and `gen_batch_initial_conditions` otherwise. Must be specified
467
+ for nonlinear inequality constraints.
468
+ timeout_sec: Max amount of time optimization can run for.
469
+ return_full_tree:
470
+ ic_gen_kwargs: Additional keyword arguments passed to function specified by
471
+ `ic_generator`
439
472
440
473
Returns:
441
474
A two-element tuple containing
@@ -481,7 +514,10 @@ def optimize_acqf(
481
514
return_best_only = return_best_only ,
482
515
gen_candidates = gen_candidates ,
483
516
sequential = sequential ,
484
- kwargs = kwargs ,
517
+ ic_generator = ic_generator ,
518
+ timeout_sec = timeout_sec ,
519
+ return_full_tree = return_full_tree ,
520
+ ic_gen_kwargs = ic_gen_kwargs ,
485
521
)
486
522
return _optimize_acqf (opt_acqf_inputs )
487
523
@@ -501,8 +537,7 @@ def _optimize_acqf(opt_inputs: OptimizeAcqfInputs) -> Tuple[Tensor, Tensor]:
501
537
)
502
538
503
539
start_time : float = time .monotonic ()
504
- kwargs = opt_inputs .kwargs
505
- timeout_sec = kwargs .pop ("timeout_sec" , None )
540
+ timeout_sec = opt_inputs .timeout_sec
506
541
507
542
# Perform sequential optimization via successive conditioning on pending points
508
543
if opt_inputs .sequential and opt_inputs .q > 1 :
@@ -531,7 +566,11 @@ def optimize_acqf_cyclic(
531
566
post_processing_func : Optional [Callable [[Tensor ], Tensor ]] = None ,
532
567
batch_initial_conditions : Optional [Tensor ] = None ,
533
568
cyclic_options : Optional [Dict [str , Union [bool , float , int , str ]]] = None ,
534
- ** kwargs ,
569
+ * ,
570
+ ic_generator : Optional [TGenInitialConditions ] = None ,
571
+ timeout_sec : Optional [float ] = None ,
572
+ return_full_tree : bool = False ,
573
+ ** ic_gen_kwargs : Any ,
535
574
) -> Tuple [Tensor , Tensor ]:
536
575
r"""Generate a set of `q` candidates via cyclic optimization.
537
576
@@ -561,6 +600,15 @@ def optimize_acqf_cyclic(
561
600
If no initial conditions are provided, the default initialization will
562
601
be used.
563
602
cyclic_options: Options for stopping criterion for outer cyclic optimization.
603
+ ic_generator: Function for generating initial conditions. Not needed when
604
+ `batch_initial_conditions` are provided. Defaults to
605
+ `gen_one_shot_kg_initial_conditions` for `qKnowledgeGradient` acquisition
606
+ functions and `gen_batch_initial_conditions` otherwise. Must be specified
607
+ for nonlinear inequality constraints.
608
+ timeout_sec: Max amount of time optimization can run for.
609
+ return_full_tree:
610
+ ic_gen_kwargs: Additional keyword arguments passed to function specified by
611
+ `ic_generator`
564
612
565
613
Returns:
566
614
A two-element tuple containing
@@ -596,7 +644,10 @@ def optimize_acqf_cyclic(
596
644
return_best_only = True ,
597
645
gen_candidates = gen_candidates_scipy ,
598
646
sequential = True ,
599
- kwargs = kwargs ,
647
+ ic_generator = ic_generator ,
648
+ timeout_sec = timeout_sec ,
649
+ return_full_tree = return_full_tree ,
650
+ ic_gen_kwargs = ic_gen_kwargs ,
600
651
)
601
652
602
653
# for the first cycle, optimize the q candidates sequentially
@@ -778,6 +829,8 @@ def optimize_acqf_mixed(
778
829
transformations).
779
830
batch_initial_conditions: A tensor to specify the initial conditions. Set
780
831
this if you do not want to use default initialization strategy.
832
+ kwargs: kwargs do nothing. This is provided so that the same arguments can
833
+ be passed to different acquisition functions without raising an error.
781
834
782
835
Returns:
783
836
A two-element tuple containing
@@ -795,6 +848,7 @@ def optimize_acqf_mixed(
795
848
"are currently not supported when `q > 1`. This is needed to "
796
849
"compute the joint acquisition value."
797
850
)
851
+ _raise_deprecation_warning_if_kwargs ("optimize_acqf_mixed" , kwargs )
798
852
799
853
if q == 1 :
800
854
ff_candidate_list , ff_acq_value_list = [], []
@@ -881,6 +935,8 @@ def optimize_acqf_discrete(
881
935
a large training set.
882
936
unique: If True return unique choices, o/w choices may be repeated
883
937
(only relevant if `q > 1`).
938
+ kwargs: kwargs do nothing. This is provided so that the same arguments can
939
+ be passed to different acquisition functions without raising an error.
884
940
885
941
Returns:
886
942
A three-element tuple containing
@@ -895,6 +951,7 @@ def optimize_acqf_discrete(
895
951
)
896
952
if choices .numel () == 0 :
897
953
raise InputDataError ("`choices` must be non-emtpy." )
954
+ _raise_deprecation_warning_if_kwargs ("optimize_acqf_discrete" , kwargs )
898
955
choices_batched = choices .unsqueeze (- 2 )
899
956
if q > 1 :
900
957
candidate_list , acq_value_list = [], []
@@ -1045,13 +1102,16 @@ def optimize_acqf_discrete_local_search(
1045
1102
a large training set.
1046
1103
unique: If True return unique choices, o/w choices may be repeated
1047
1104
(only relevant if `q > 1`).
1105
+ kwargs: kwargs do nothing. This is provided so that the same arguments can
1106
+ be passed to different acquisition functions without raising an error.
1048
1107
1049
1108
Returns:
1050
1109
A two-element tuple containing
1051
1110
1052
1111
- a `q x d`-dim tensor of generated candidates.
1053
1112
- an associated acquisition value.
1054
1113
"""
1114
+ _raise_deprecation_warning_if_kwargs ("optimize_acqf_discrete_local_search" , kwargs )
1055
1115
candidate_list = []
1056
1116
base_X_pending = acq_function .X_pending if q > 1 else None
1057
1117
base_X_avoid = X_avoid
0 commit comments