@@ -109,12 +109,43 @@ def __post_init__(self) -> None:
109109 "3-dimensional. Its shape is "
110110 f"{ batch_initial_conditions_shape } ."
111111 )
112+
112113 if batch_initial_conditions_shape [- 1 ] != d :
113114 raise ValueError (
114115 f"batch_initial_conditions.shape[-1] must be { d } . The "
115116 f"shape is { batch_initial_conditions_shape } ."
116117 )
117118
119+ if len (batch_initial_conditions_shape ) == 2 :
120+ warnings .warn (
121+ "If using a 2-dim `batch_initial_conditions` botorch will "
122+ "default to old behavior of ignoring `num_restarts` and just "
123+ "use the given `batch_initial_conditions` by setting "
124+ "`raw_samples` to None." ,
125+ RuntimeWarning ,
126+ stacklevel = 3 ,
127+ )
128+ # Use object.__setattr__ to bypass immutability and set a value
129+ object .__setattr__ (self , "raw_samples" , None )
130+
131+ if (
132+ len (batch_initial_conditions_shape ) == 3
133+ and batch_initial_conditions_shape [0 ] < self .num_restarts
134+ and batch_initial_conditions_shape [- 2 ] != self .q
135+ ):
136+ warnings .warn (
137+ "If using a 3-dim `batch_initial_conditions` where the "
138+ "first dimension is less than `num_restarts` and the second "
139+ "dimension is not equal to `q`, botorch will default to "
140+ "old behavior of ignoring `num_restarts` and just use the "
141+ "given `batch_initial_conditions` by setting `raw_samples` "
142+ "to None." ,
143+ RuntimeWarning ,
144+ stacklevel = 3 ,
145+ )
146+ # Use object.__setattr__ to bypass immutability and set a value
147+ object .__setattr__ (self , "raw_samples" , None )
148+
118149 elif self .ic_generator is None :
119150 if self .nonlinear_inequality_constraints is not None :
120151 raise RuntimeError (
@@ -126,6 +157,7 @@ def __post_init__(self) -> None:
126157 "Must specify `raw_samples` when "
127158 "`batch_initial_conditions` is None`."
128159 )
160+
129161 if self .fixed_features is not None and any (
130162 (k < 0 for k in self .fixed_features )
131163 ):
@@ -253,20 +285,49 @@ def _optimize_acqf_sequential_q(
253285 return candidates , torch .stack (acq_value_list )
254286
255287
288+ def _combine_initial_conditions (
289+ provided_initial_conditions : Tensor | None = None ,
290+ generated_initial_conditions : Tensor | None = None ,
291+ dim = 0 ,
292+ ) -> Tensor :
293+ if (
294+ provided_initial_conditions is not None
295+ and generated_initial_conditions is not None
296+ ):
297+ return torch .cat (
298+ [provided_initial_conditions , generated_initial_conditions ], dim = dim
299+ )
300+ elif provided_initial_conditions is not None :
301+ return provided_initial_conditions
302+ elif generated_initial_conditions is not None :
303+ return generated_initial_conditions
304+ else :
305+ raise ValueError (
306+ "Either `batch_initial_conditions` or `raw_samples` must be set."
307+ )
308+
309+
256310def _optimize_acqf_batch (opt_inputs : OptimizeAcqfInputs ) -> tuple [Tensor , Tensor ]:
257311 options = opt_inputs .options or {}
258312
259- initial_conditions_provided = opt_inputs .batch_initial_conditions is not None
313+ required_num_restarts = opt_inputs .num_restarts
314+ provided_initial_conditions = opt_inputs .batch_initial_conditions
315+ generated_initial_conditions = None
260316
261- if initial_conditions_provided :
262- batch_initial_conditions = opt_inputs .batch_initial_conditions
263- else :
264- # pyre-ignore[28]: Unexpected keyword argument `acq_function` to anonymous call.
265- batch_initial_conditions = opt_inputs .get_ic_generator ()(
317+ if (
318+ provided_initial_conditions is not None
319+ and len (provided_initial_conditions .shape ) == 3
320+ ):
321+ required_num_restarts -= provided_initial_conditions .shape [0 ]
322+
323+ if opt_inputs .raw_samples is not None and required_num_restarts > 0 :
324+ # pyre-ignore[28]: Unexpected keyword argument `acq_function`
325+ # to anonymous call.
326+ generated_initial_conditions = opt_inputs .get_ic_generator ()(
266327 acq_function = opt_inputs .acq_function ,
267328 bounds = opt_inputs .bounds ,
268329 q = opt_inputs .q ,
269- num_restarts = opt_inputs . num_restarts ,
330+ num_restarts = required_num_restarts ,
270331 raw_samples = opt_inputs .raw_samples ,
271332 fixed_features = opt_inputs .fixed_features ,
272333 options = options ,
@@ -275,6 +336,11 @@ def _optimize_acqf_batch(opt_inputs: OptimizeAcqfInputs) -> tuple[Tensor, Tensor
275336 ** opt_inputs .ic_gen_kwargs ,
276337 )
277338
339+ batch_initial_conditions = _combine_initial_conditions (
340+ provided_initial_conditions = provided_initial_conditions ,
341+ generated_initial_conditions = generated_initial_conditions ,
342+ )
343+
278344 batch_limit : int = options .get (
279345 "batch_limit" ,
280346 (
@@ -344,23 +410,24 @@ def _optimize_batch_candidates() -> tuple[Tensor, Tensor, list[Warning]]:
344410 first_warn_msg = (
345411 "Optimization failed in `gen_candidates_scipy` with the following "
346412 f"warning(s):\n { [w .message for w in ws ]} \n Because you specified "
347- "`batch_initial_conditions`, optimization will not be retried with "
348- "new initial conditions and will proceed with the current solution."
349- " Suggested remediation: Try again with different "
350- "`batch_initial_conditions`, or don't provide `batch_initial_conditions.`"
351- if initial_conditions_provided
413+ "`batch_initial_conditions` larger than required `num_restarts`, "
414+ "optimization will not be retried with new initial conditions and "
415+ "will proceed with the current solution. Suggested remediation: "
416+ "Try again with different `batch_initial_conditions`, don't provide "
417+ "`batch_initial_conditions`, or increase `num_restarts`."
418+ if batch_initial_conditions is not None and required_num_restarts <= 0
352419 else "Optimization failed in `gen_candidates_scipy` with the following "
353420 f"warning(s):\n { [w .message for w in ws ]} \n Trying again with a new "
354421 "set of initial conditions."
355422 )
356423 warnings .warn (first_warn_msg , RuntimeWarning , stacklevel = 2 )
357424
358- if not initial_conditions_provided :
359- batch_initial_conditions = opt_inputs .get_ic_generator ()(
425+ if opt_inputs . raw_samples is not None and required_num_restarts > 0 :
426+ generated_initial_conditions = opt_inputs .get_ic_generator ()(
360427 acq_function = opt_inputs .acq_function ,
361428 bounds = opt_inputs .bounds ,
362429 q = opt_inputs .q ,
363- num_restarts = opt_inputs . num_restarts ,
430+ num_restarts = required_num_restarts ,
364431 raw_samples = opt_inputs .raw_samples ,
365432 fixed_features = opt_inputs .fixed_features ,
366433 options = options ,
@@ -369,6 +436,11 @@ def _optimize_batch_candidates() -> tuple[Tensor, Tensor, list[Warning]]:
369436 ** opt_inputs .ic_gen_kwargs ,
370437 )
371438
439+ batch_initial_conditions = _combine_initial_conditions (
440+ provided_initial_conditions = provided_initial_conditions ,
441+ generated_initial_conditions = generated_initial_conditions ,
442+ )
443+
372444 batch_candidates , batch_acq_values , ws = _optimize_batch_candidates ()
373445
374446 optimization_warning_raised = any (
@@ -1177,7 +1249,7 @@ def _gen_batch_initial_conditions_local_search(
11771249 inequality_constraints : list [tuple [Tensor , Tensor , float ]],
11781250 min_points : int ,
11791251 max_tries : int = 100 ,
1180- ):
1252+ ) -> Tensor :
11811253 """Generate initial conditions for local search."""
11821254 device = discrete_choices [0 ].device
11831255 dtype = discrete_choices [0 ].dtype
@@ -1197,6 +1269,58 @@ def _gen_batch_initial_conditions_local_search(
11971269 raise RuntimeError (f"Failed to generate at least { min_points } initial conditions" )
11981270
11991271
1272+ def _gen_starting_points_local_search (
1273+ discrete_choices : list [Tensor ],
1274+ raw_samples : int ,
1275+ batch_initial_conditions : Tensor ,
1276+ X_avoid : Tensor ,
1277+ inequality_constraints : list [tuple [Tensor , Tensor , float ]],
1278+ min_points : int ,
1279+ acq_function : AcquisitionFunction ,
1280+ max_batch_size : int = 2048 ,
1281+ max_tries : int = 100 ,
1282+ ) -> Tensor :
1283+ required_min_points = min_points
1284+ provided_X0 = None
1285+ generated_X0 = None
1286+
1287+ if batch_initial_conditions is not None :
1288+ provided_X0 = _filter_invalid (
1289+ X = batch_initial_conditions .squeeze (1 ), X_avoid = X_avoid
1290+ )
1291+ provided_X0 = _filter_infeasible (
1292+ X = provided_X0 , inequality_constraints = inequality_constraints
1293+ ).unsqueeze (1 )
1294+ required_min_points -= batch_initial_conditions .shape [0 ]
1295+
1296+ if required_min_points > 0 :
1297+ generated_X0 = _gen_batch_initial_conditions_local_search (
1298+ discrete_choices = discrete_choices ,
1299+ raw_samples = raw_samples ,
1300+ X_avoid = X_avoid ,
1301+ inequality_constraints = inequality_constraints ,
1302+ min_points = min_points ,
1303+ max_tries = max_tries ,
1304+ )
1305+
1306+ # pick the best starting points
1307+ with torch .no_grad ():
1308+ acqvals_init = _split_batch_eval_acqf (
1309+ acq_function = acq_function ,
1310+ X = generated_X0 .unsqueeze (1 ),
1311+ max_batch_size = max_batch_size ,
1312+ ).unsqueeze (- 1 )
1313+
1314+ generated_X0 = generated_X0 [
1315+ acqvals_init .topk (k = min_points , largest = True , dim = 0 ).indices
1316+ ]
1317+
1318+ return _combine_initial_conditions (
1319+ provided_initial_conditions = provided_X0 if provided_X0 is not None else None ,
1320+ generated_initial_conditions = generated_X0 if generated_X0 is not None else None ,
1321+ )
1322+
1323+
12001324def optimize_acqf_discrete_local_search (
12011325 acq_function : AcquisitionFunction ,
12021326 discrete_choices : list [Tensor ],
@@ -1207,6 +1331,7 @@ def optimize_acqf_discrete_local_search(
12071331 X_avoid : Tensor | None = None ,
12081332 batch_initial_conditions : Tensor | None = None ,
12091333 max_batch_size : int = 2048 ,
1334+ max_tries : int = 100 ,
12101335 unique : bool = True ,
12111336) -> tuple [Tensor , Tensor ]:
12121337 r"""Optimize acquisition function over a lattice.
@@ -1238,6 +1363,8 @@ def optimize_acqf_discrete_local_search(
12381363 max_batch_size: The maximum number of choices to evaluate in batch.
12391364 A large limit can cause excessive memory usage if the model has
12401365 a large training set.
1366+ max_tries: Maximum number of iterations to try when generating initial
1367+ conditions.
12411368 unique: If True return unique choices, o/w choices may be repeated
12421369 (only relevant if `q > 1`).
12431370
@@ -1247,6 +1374,16 @@ def optimize_acqf_discrete_local_search(
12471374 - a `q x d`-dim tensor of generated candidates.
12481375 - an associated acquisition value.
12491376 """
1377+ if batch_initial_conditions is not None :
1378+ if not (
1379+ len (batch_initial_conditions .shape ) == 3
1380+ and batch_initial_conditions .shape [- 2 ] == 1
1381+ ):
1382+ raise ValueError (
1383+ "batch_initial_conditions must have shape `n x 1 x d` if "
1384+ f"given (received shape { batch_initial_conditions .shape } )."
1385+ )
1386+
12501387 candidate_list = []
12511388 base_X_pending = acq_function .X_pending if q > 1 else None
12521389 base_X_avoid = X_avoid
@@ -1259,27 +1396,18 @@ def optimize_acqf_discrete_local_search(
12591396 inequality_constraints = inequality_constraints or []
12601397 for i in range (q ):
12611398 # generate some starting points
1262- if i == 0 and batch_initial_conditions is not None :
1263- X0 = _filter_invalid (X = batch_initial_conditions .squeeze (1 ), X_avoid = X_avoid )
1264- X0 = _filter_infeasible (
1265- X = X0 , inequality_constraints = inequality_constraints
1266- ).unsqueeze (1 )
1267- else :
1268- X_init = _gen_batch_initial_conditions_local_search (
1269- discrete_choices = discrete_choices ,
1270- raw_samples = raw_samples ,
1271- X_avoid = X_avoid ,
1272- inequality_constraints = inequality_constraints ,
1273- min_points = num_restarts ,
1274- )
1275- # pick the best starting points
1276- with torch .no_grad ():
1277- acqvals_init = _split_batch_eval_acqf (
1278- acq_function = acq_function ,
1279- X = X_init .unsqueeze (1 ),
1280- max_batch_size = max_batch_size ,
1281- ).unsqueeze (- 1 )
1282- X0 = X_init [acqvals_init .topk (k = num_restarts , largest = True , dim = 0 ).indices ]
1399+ X0 = _gen_starting_points_local_search (
1400+ discrete_choices = discrete_choices ,
1401+ raw_samples = raw_samples ,
1402+ batch_initial_conditions = batch_initial_conditions ,
1403+ X_avoid = X_avoid ,
1404+ inequality_constraints = inequality_constraints ,
1405+ min_points = num_restarts ,
1406+ acq_function = acq_function ,
1407+ max_batch_size = max_batch_size ,
1408+ max_tries = max_tries ,
1409+ )
1410+ batch_initial_conditions = None
12831411
12841412 # optimize from the best starting points
12851413 best_xs = torch .zeros (len (X0 ), dim , device = device , dtype = dtype )
0 commit comments