 import warnings
 from collections.abc import Callable
 from functools import partial
-from typing import Any, NoReturn
+from typing import Any, Mapping, NoReturn
 
 import numpy as np
 import numpy.typing as npt
@@ -64,7 +64,7 @@ def gen_candidates_scipy(
     equality_constraints: list[tuple[Tensor, Tensor, float]] | None = None,
     nonlinear_inequality_constraints: list[tuple[Callable, bool]] | None = None,
     options: dict[str, Any] | None = None,
-    fixed_features: dict[int, float | None] | None = None,
+    fixed_features: Mapping[int, float | Tensor] | None = None,
     timeout_sec: float | None = None,
     use_parallel_mode: bool | None = None,
 ) -> tuple[Tensor, Tensor]:
@@ -107,11 +107,11 @@ def gen_candidates_scipy(
             and SLSQP if inequality or equality constraints are present. If
             `with_grad=False`, then we use a two-point finite difference estimate
             of the gradient.
-        fixed_features: This is a dictionary of feature indices to values, where
+        fixed_features: A mapping of feature indices to values, where
             all generated candidates will have features fixed to these values.
-            If the dictionary value is None, then that feature will just be
-            fixed to the clamped value and not optimized. Assumes values to be
-            compatible with lower_bounds and upper_bounds!
+            If passing tensors as values, they should have either shape `b` or
+            `b x q` to fix the same feature to different values in the batch.
+            Assumes values to be compatible with lower_bounds and upper_bounds!
         timeout_sec: Timeout (in seconds) for `scipy.optimize.minimize` routine -
             if provided, optimization will stop after this many seconds and return
             the best solution found so far.
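
A minimal usage sketch of the batch-shaped fixed features described above (not part of the diff; the quadratic `acqf` is a stand-in for a real acquisition function):

```python
import torch
from botorch.generation.gen import gen_candidates_scipy

def acqf(X: torch.Tensor) -> torch.Tensor:
    # Stand-in acquisition function: one value per b-batch element.
    return -(X - 0.3).pow(2).sum(dim=(-1, -2))

# b=3 batches of q=2 candidates in d=4 dimensions.
ics = torch.rand(3, 2, 4)

candidates, acq_values = gen_candidates_scipy(
    initial_conditions=ics,
    acquisition_function=acqf,
    lower_bounds=torch.zeros(4),
    upper_bounds=torch.ones(4),
    # Feature 0 is fixed to 0.5 everywhere; feature 2 gets a different
    # value per batch element (shape `b`, broadcast across q).
    fixed_features={0: 0.5, 2: torch.tensor([0.1, 0.2, 0.3])},
)
```
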
@@ -211,18 +211,17 @@ def f(x):
             timeout_sec=timeout_sec,
         )
 
+    f_np_wrapper = _get_f_np_wrapper(
+        clamped_candidates.shape,
+        initial_conditions.device,
+        initial_conditions.dtype,
+        with_grad,
+    )
+
     if not why_not_fast_path and use_parallel_mode is not False:
         if is_constrained:
             raise RuntimeWarning("Method L-BFGS-B cannot handle constraints.")
 
-        f_np_wrapper = _get_f_np_wrapper(
-            clamped_candidates.shape,
-            initial_conditions.device,
-            initial_conditions.dtype,
-            with_grad,
-            batched=True,
-        )
-
         batched_x0 = _arrayify(clamped_candidates).reshape(len(clamped_candidates), -1)
 
         l_bfgs_b_bounds = translate_bounds_for_lbfgsb(
@@ -242,6 +241,7 @@ def f(x):
             bounds=l_bfgs_b_bounds,
             # constraints=constraints,
             callback=options.get("callback", None),
+            pass_batch_indices=True,
             **minimize_options,
         )
         for res in results:
@@ -264,21 +264,38 @@ def f(x):
         else:
             logger.debug(msg)
 
-        f_np_wrapper = _get_f_np_wrapper(
-            clamped_candidates.shape,
-            initial_conditions.device,
-            initial_conditions.dtype,
-            with_grad,
-        )
+        if (
+            fixed_features
+            and any(
+                torch.is_tensor(ff) and ff.ndim > 0 for ff in fixed_features.values()
+            )
+            and max_optimization_problem_aggregation_size != 1
+        ):
+            raise UnsupportedError(
+                "Batch-shaped fixed features are not supported "
+                "when optimizing more than one optimization "
+                "problem at a time."
+            )
 
         all_xs = []
         split_candidates = clamped_candidates.split(
             max_optimization_problem_aggregation_size
         )
-        for candidates_ in split_candidates:
-            # We optimize the candidates at hand as a single problem
+        for i, candidates_ in enumerate(split_candidates):
+            if fixed_features:
+                fixed_features_ = {
+                    # From the check above we know that each split holds a
+                    # single candidate, so index i selects its batch entry.
+                    k: ff[i : i + 1].item()
+                    if torch.is_tensor(ff) and ff.ndim > 0
+                    else ff
+                    for k, ff in fixed_features.items()
+                }
+            else:
+                fixed_features_ = None
+
             _no_fixed_features = _remove_fixed_features_from_optimization(
-                fixed_features=fixed_features,
+                fixed_features=fixed_features_,
                 acquisition_function=acquisition_function,
                 initial_conditions=None,
                 d=initial_conditions_all_features.shape[-1],
@@ -296,7 +313,7 @@ def f(x):
 
             f_np_wrapper_ = partial(
                 f_np_wrapper,
-                fixed_features=fixed_features,
+                fixed_features=fixed_features_,
             )
 
             x0 = candidates_.flatten()
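
For intuition, the per-problem extraction above reduces batch-shaped tensors to plain floats once the batch has been split into single-candidate problems. A standalone sketch with hypothetical values:

```python
import torch

fixed_features = {2: torch.tensor([0.1, 0.2, 0.3])}  # shape `b` with b=3

for i in range(3):  # one optimization problem per batch element
    fixed_features_ = {
        k: ff[i : i + 1].item() if torch.is_tensor(ff) and ff.ndim > 0 else ff
        for k, ff in fixed_features.items()
    }
    print(fixed_features_)  # {2: 0.1}, then {2: 0.2}, then {2: 0.3} (up to float precision)
```
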
@@ -363,13 +380,14 @@ def f(x):
     return clamped_candidates, batch_acquisition
 
 
-def _get_f_np_wrapper(shapeX, device, dtype, with_grad, batched=False):
+def _get_f_np_wrapper(shapeX, device, dtype, with_grad):
     if with_grad:
 
         def f_np_wrapper(
             x: npt.NDArray,
             f: Callable,
-            fixed_features: dict[int, float] | None,
+            fixed_features: Mapping[int, float | Tensor] | None,
+            batch_indices: list[int] | None = None,
         ) -> tuple[float | npt.NDArray, npt.NDArray]:
             """Given a torch callable, compute value + grad given a numpy array."""
             if np.isnan(x).any():
@@ -387,8 +405,21 @@ def f_np_wrapper(
                 .contiguous()
                 .requires_grad_(True)
             )
+            if fixed_features is not None:
+                if batch_indices is not None:
+                    this_fixed_features = {
+                        k: ff[batch_indices]
+                        if torch.is_tensor(ff) and ff.ndim > 0
+                        else ff
+                        for k, ff in fixed_features.items()
+                    }
+                else:
+                    this_fixed_features = fixed_features
+            else:
+                this_fixed_features = None
+
             X_fix = fix_features(
-                X, fixed_features=fixed_features, replace_current_value=False
+                X, fixed_features=this_fixed_features, replace_current_value=False
             )
             # we compute the loss on the whole batch, under the assumption that f
             # treats multiple inputs in the 0th dimension as independent
@@ -409,7 +440,7 @@ def f_np_wrapper(
                 raise OptimizationGradientError(msg, current_x=x)
             fval = (
                 losses.detach().view(-1).cpu().numpy()
-                if batched
+                if batch_indices is not None
                 else loss.detach().item()
             )  # the view(-1) seems necessary as f might return a single scalar
             return fval, gradf
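
The `batch_indices` plumbing can be checked in isolation; this sketch mirrors the selection logic inside the wrapper (values are hypothetical):

```python
import torch

fixed_features = {0: 0.5, 2: torch.tensor([0.1, 0.2, 0.3])}
batch_indices = [0, 2]  # sub-batch currently evaluated by L-BFGS-B

this_fixed_features = {
    k: ff[batch_indices] if torch.is_tensor(ff) and ff.ndim > 0 else ff
    for k, ff in fixed_features.items()
}
# Floats pass through unchanged; batch-shaped tensors are narrowed to the
# active sub-batch: {0: 0.5, 2: tensor([0.1000, 0.3000])}
```
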
@@ -485,7 +516,7 @@ def gen_candidates_torch(
     optimizer: type[Optimizer] = torch.optim.Adam,
     options: dict[str, float | str] | None = None,
     callback: Callable[[int, Tensor, Tensor], NoReturn] | None = None,
-    fixed_features: dict[int, float | None] | None = None,
+    fixed_features: Mapping[int, float | Tensor] | None = None,
     timeout_sec: float | None = None,
 ) -> tuple[Tensor, Tensor]:
     r"""Generate a set of candidates using a `torch.optim` optimizer.
@@ -507,9 +538,10 @@ def gen_candidates_torch(
             the loss and gradients, but before calling the optimizer.
         fixed_features: This is a dictionary of feature indices to values, where
             all generated candidates will have features fixed to these values.
-            If the dictionary value is None, then that feature will just be
-            fixed to the clamped value and not optimized. Assumes values to be
-            compatible with lower_bounds and upper_bounds!
+            Unlike `gen_candidates_scipy`, only float values are supported
+            here; tensor-valued fixed features will raise an
+            `UnsupportedError` (see the check below).
+            Assumes values to be compatible with lower_bounds and upper_bounds!
         timeout_sec: Timeout (in seconds) for optimization. If provided,
             `gen_candidates_torch` will stop after this many seconds and return
             the best solution found so far.
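
A corresponding sketch for the torch path, which accepts only float-valued fixed features (same stand-in acquisition function as above):

```python
import torch
from botorch.generation.gen import gen_candidates_torch

def acqf(X: torch.Tensor) -> torch.Tensor:
    return -(X - 0.3).pow(2).sum(dim=(-1, -2))

candidates, acq_values = gen_candidates_torch(
    initial_conditions=torch.rand(3, 2, 4),
    acquisition_function=acqf,
    lower_bounds=torch.zeros(4),
    upper_bounds=torch.ones(4),
    fixed_features={0: 0.5},  # a tensor value here raises UnsupportedError
)
```
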
@@ -533,12 +565,18 @@ def gen_candidates_torch(
                     upper_bounds=bounds[1],
                 )
     """
-    assert not fixed_features or not any(
-        torch.is_tensor(v) for v in fixed_features.values()
-    ), "`gen_candidates_torch` does not support tensor-valued fixed features."
+    if fixed_features and any(torch.is_tensor(v) for v in fixed_features.values()):
+        raise UnsupportedError(
+            "`gen_candidates_torch` does not support tensor-valued fixed features."
+        )
 
     start_time = time.monotonic()
     options = options or {}
+    # We remove max_optimization_problem_aggregation_size since it does not
+    # affect the first-order optimizers used in this method: here it makes
+    # no difference whether multiple optimization problems are combined
+    # into one or not.
+    options.pop("max_optimization_problem_aggregation_size", None)
 
     # if there are fixed features we may optimize over a domain of lower dimension
     if fixed_features:
@@ -572,7 +610,13 @@ def gen_candidates_torch(
         )
         return clamped_candidates, batch_acquisition
     _clamp = partial(columnwise_clamp, lower=lower_bounds, upper=upper_bounds)
-    clamped_candidates = _clamp(initial_conditions).requires_grad_(True)
+    clamped_candidates = _clamp(initial_conditions)
+    if fixed_features:
+        clamped_candidates = clamped_candidates[
+            ...,
+            [i for i in range(clamped_candidates.shape[-1]) if i not in fixed_features],
+        ]
+    clamped_candidates = clamped_candidates.requires_grad_(True)
     _optimizer = optimizer(params=[clamped_candidates], lr=options.get("lr", 0.025))
 
     i = 0
@@ -583,7 +627,7 @@ def gen_candidates_torch(
         with torch.no_grad():
             X = _clamp(clamped_candidates).requires_grad_(True)
 
-        loss = -acquisition_function(X).sum()
+        loss = -acquisition_function(fix_features(X, fixed_features)).sum()
         grad = torch.autograd.grad(loss, X)[0]
         if callback:
             callback(i, loss, grad)
@@ -602,6 +646,7 @@ def assign_grad():
                 logger.info(f"Optimization timed out after {runtime} seconds.")
 
     clamped_candidates = _clamp(clamped_candidates)
+    clamped_candidates = fix_features(clamped_candidates, fixed_features)
     with torch.no_grad():
         batch_acquisition = acquisition_function(clamped_candidates)
 
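
Putting it together: on the non-parallel scipy path, batch-shaped fixed features require splitting the batch into single problems, controlled by the same options key the diff pops in `gen_candidates_torch` (the exact interplay is an assumption here):

```python
import torch
from botorch.generation.gen import gen_candidates_scipy

def acqf(X: torch.Tensor) -> torch.Tensor:
    return -(X - 0.3).pow(2).sum(dim=(-1, -2))

candidates, acq_values = gen_candidates_scipy(
    initial_conditions=torch.rand(3, 2, 4),
    acquisition_function=acqf,
    lower_bounds=torch.zeros(4),
    upper_bounds=torch.ones(4),
    fixed_features={2: torch.tensor([0.1, 0.2, 0.3])},
    # Batch-shaped fixed features must be optimized one problem at a time,
    # otherwise the UnsupportedError above is raised.
    options={"max_optimization_problem_aggregation_size": 1},
)
```
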