@@ -65,7 +65,7 @@ class OptimizeAcqfInputs:
6565 See docstring for `optimize_acqf` for explanation of parameters.
6666 """
6767
68- acq_function : AcquisitionFunction
68+ acq_function : AcquisitionFunction | None
6969 bounds : Tensor
7070 q : int
7171 num_restarts : int
@@ -85,6 +85,7 @@ class OptimizeAcqfInputs:
8585 return_full_tree : bool = False
8686 retry_on_optimization_warning : bool = True
8787 ic_gen_kwargs : dict = dataclasses .field (default_factory = dict )
88+ acq_function_sequence : list [AcquisitionFunction ] | None = None
8889
8990 @property
9091 def full_tree (self ) -> bool :
@@ -93,6 +94,10 @@ def full_tree(self) -> bool:
9394 )
9495
9596 def __post_init__ (self ) -> None :
97+ if self .acq_function is None and self .acq_function_sequence is None :
98+ raise ValueError (
99+ "Either `acq_function` or `acq_function_sequence` must be specified."
100+ )
96101 if self .inequality_constraints is None and not (
97102 self .bounds .ndim == 2 and self .bounds .shape [0 ] == 2
98103 ):
@@ -168,6 +173,16 @@ def __post_init__(self) -> None:
168173 ):
169174 raise ValueError ("All indices (keys) in `fixed_features` must be >= 0." )
170175
176+ if self .acq_function_sequence is not None :
177+ if not self .sequential :
178+ raise ValueError (
179+ "acq_function_sequence requires sequential optimization."
180+ )
181+ if len (self .acq_function_sequence ) != self .q :
182+ raise ValueError ("acq_function_sequence must have length q." )
183+ if self .q < 2 :
184+ raise ValueError ("acq_function_sequence requires q > 1." )
185+
171186 def get_ic_generator (self ) -> TGenInitialConditions :
172187 if self .ic_generator is not None :
173188 return self .ic_generator
@@ -264,29 +279,47 @@ def _optimize_acqf_sequential_q(
264279 else None
265280 )
266281 candidate_list , acq_value_list = [], []
267- base_X_pending = opt_inputs .acq_function .X_pending
282+ if opt_inputs .acq_function_sequence is None :
283+ acq_function_sequence = [opt_inputs .acq_function ]
284+ else :
285+ acq_function_sequence = opt_inputs .acq_function_sequence
286+ base_X_pending = [acqf .X_pending for acqf in acq_function_sequence ]
287+ n_acq = len (acq_function_sequence )
288+
289+ new_kwargs = {
290+ "q" : 1 ,
291+ "batch_initial_conditions" : None ,
292+ "return_best_only" : True ,
293+ "sequential" : False ,
294+ "timeout_sec" : timeout_sec ,
295+ "acq_function_sequence" : None ,
296+ }
297+ new_inputs = dataclasses .replace (opt_inputs , ** new_kwargs )
268298
269- new_inputs = dataclasses .replace (
270- opt_inputs ,
271- q = 1 ,
272- batch_initial_conditions = None ,
273- return_best_only = True ,
274- sequential = False ,
275- timeout_sec = timeout_sec ,
276- )
277299 for i in range (opt_inputs .q ):
300+ if n_acq > 1 :
301+ acq_function = acq_function_sequence [i ]
302+ new_kwargs ["acq_function" ] = acq_function
303+ new_inputs = dataclasses .replace (opt_inputs , ** new_kwargs )
304+ if len (candidate_list ) > 0 :
305+ candidates = torch .cat (candidate_list , dim = - 2 )
306+ new_inputs .acq_function .set_X_pending (
307+ torch .cat ([base_X_pending [i % n_acq ], candidates ], dim = - 2 )
308+ if base_X_pending [i % n_acq ] is not None
309+ else candidates
310+ )
278311 candidate , acq_value = _optimize_acqf_batch (new_inputs )
279312
280313 candidate_list .append (candidate )
281314 acq_value_list .append (acq_value )
282- candidates = torch .cat (candidate_list , dim = - 2 )
283- new_inputs .acq_function .set_X_pending (
284- torch .cat ([base_X_pending , candidates ], dim = - 2 )
285- if base_X_pending is not None
286- else candidates
287- )
315+
288316 logger .info (f"Generated sequential candidate { i + 1 } of { opt_inputs .q } " )
289- opt_inputs .acq_function .set_X_pending (base_X_pending )
317+ model_name = type (new_inputs .acq_function .model ).__name__
318+ logger .debug (f"Used model { model_name } for candidate generation." )
319+ candidates = torch .cat (candidate_list , dim = - 2 )
320+ # Re-set X_pendings on the acquisitions to base values
321+ for acqf , X_pending in zip (acq_function_sequence , base_X_pending ):
322+ acqf .set_X_pending (X_pending )
290323 return candidates , torch .stack (acq_value_list )
291324
292325
@@ -517,7 +550,7 @@ def _optimize_batch_candidates() -> tuple[Tensor, Tensor, list[Warning]]:
517550
518551
519552def optimize_acqf (
520- acq_function : AcquisitionFunction ,
553+ acq_function : AcquisitionFunction | None ,
521554 bounds : Tensor ,
522555 q : int ,
523556 num_restarts : int ,
@@ -532,6 +565,7 @@ def optimize_acqf(
532565 return_best_only : bool = True ,
533566 gen_candidates : TGenCandidates | None = None ,
534567 sequential : bool = False ,
568+ acq_function_sequence : list [AcquisitionFunction ] | None = None ,
535569 * ,
536570 ic_generator : TGenInitialConditions | None = None ,
537571 timeout_sec : float | None = None ,
@@ -627,6 +661,10 @@ def optimize_acqf(
627661 inputs. Default: `gen_candidates_scipy`
628662 sequential: If False, uses joint optimization, otherwise uses sequential
629663 optimization for optimizing multiple joint candidates (q > 1).
664+ acq_function_sequence: A list of acquisition functions to be optimized
665+ sequentially. Must be of length q>1, and requires sequential=True. Used
666+ for ensembling candidates from different acquisition functions. If
667+ omitted, use `acq_function` to generate all `q` candidates.
630668 ic_generator: Function for generating initial conditions. Not needed when
631669 `batch_initial_conditions` are provided. Defaults to
632670 `gen_one_shot_kg_initial_conditions` for `qKnowledgeGradient` acquisition
@@ -689,6 +727,7 @@ def optimize_acqf(
689727 return_full_tree = return_full_tree ,
690728 retry_on_optimization_warning = retry_on_optimization_warning ,
691729 ic_gen_kwargs = ic_gen_kwargs ,
730+ acq_function_sequence = acq_function_sequence ,
692731 )
693732 return _optimize_acqf (opt_inputs = opt_acqf_inputs )
694733
0 commit comments