@@ -271,13 +271,15 @@ def gen_batch_initial_conditions(
271271 fixed_features: A map `{feature_index: value}` for features that
272272 should be fixed to a particular value during generation.
273273 options: Options for initial condition generation. For valid options see
274- `initialize_q_batch` and `initialize_q_batch_nonneg`. If `options`
275- contains a `nonnegative=True` entry, then `acq_function` is
276- assumed to be non-negative (useful when using custom acquisition
277- functions). In addition, an "init_batch_limit" option can be passed
278- to specify the batch limit for the initialization. This is useful
279- for avoiding memory limits when computing the batch posterior over
280- raw samples.
274+ `initialize_q_batch_topn`, `initialize_q_batch_nonneg`, and
275+ `initialize_q_batch`. If `options` contains a `topn=True` entry, then
276+ `initialize_q_batch_topn` will be used. Else if `options` contains a
277+ `nonnegative=True` entry, then `acq_function` is assumed to be
278+ non-negative (useful when using custom acquisition functions).
279+ `initialize_q_batch` will be used otherwise. In addition, an
280+ "init_batch_limit" option can be passed to specify the batch limit
281+ for the initialization. This is useful for avoiding memory limits
282+ when computing the batch posterior over raw samples.
281283 inequality constraints: A list of tuples (indices, coefficients, rhs),
282284 with each tuple encoding an inequality constraint of the form
283285 `\sum_i (X[indices[i]] * coefficients[i]) >= rhs`.
@@ -328,14 +330,24 @@ def gen_batch_initial_conditions(
328330 init_kwargs = {}
329331 device = bounds .device
330332 bounds_cpu = bounds .cpu ()
331- if "eta" in options :
332- init_kwargs ["eta" ] = options .get ("eta" )
333- if options .get ("nonnegative" ) or is_nonnegative (acq_function ):
333+
334+ if options .get ("topn" ):
335+ init_func = initialize_q_batch_topn
336+ init_func_opts = ["sorted" , "largest" ]
337+ elif options .get ("nonnegative" ) or is_nonnegative (acq_function ):
334338 init_func = initialize_q_batch_nonneg
335- if "alpha" in options :
336- init_kwargs ["alpha" ] = options .get ("alpha" )
339+ init_func_opts = ["alpha" , "eta" ]
337340 else :
338341 init_func = initialize_q_batch
342+ init_func_opts = ["eta" ]
343+
344+ for opt in init_func_opts :
345+ # default value of "largest" to "acq_function.maximize" if it exists
346+ if opt == "largest" and hasattr (acq_function , "maximize" ):
347+ init_kwargs [opt ] = acq_function .maximize
348+
349+ if opt in options :
350+ init_kwargs [opt ] = options .get (opt )
339351
340352 q = 1 if q is None else q
341353 # the dimension the samples are drawn from
@@ -363,7 +375,9 @@ def gen_batch_initial_conditions(
363375 X_rnd_nlzd = torch .rand (
364376 n , q , bounds_cpu .shape [- 1 ], dtype = bounds .dtype
365377 )
366- X_rnd = bounds_cpu [0 ] + (bounds_cpu [1 ] - bounds_cpu [0 ]) * X_rnd_nlzd
378+ X_rnd = unnormalize (
379+ X_rnd_nlzd , bounds_cpu , update_constant_bounds = False
380+ )
367381 else :
368382 X_rnd = sample_q_batches_from_polytope (
369383 n = n ,
@@ -375,7 +389,8 @@ def gen_batch_initial_conditions(
375389 equality_constraints = equality_constraints ,
376390 inequality_constraints = inequality_constraints ,
377391 )
378- # sample points around best
392+
393+ # sample additional points around best
379394 if sample_around_best :
380395 X_best_rnd = sample_points_around_best (
381396 acq_function = acq_function ,
@@ -395,6 +410,8 @@ def gen_batch_initial_conditions(
395410 )
396411 # Keep X on CPU for consistency & to limit GPU memory usage.
397412 X_rnd = fix_features (X_rnd , fixed_features = fixed_features ).cpu ()
413+
414+ # Append the fixed fantasies to the randomly generated points
398415 if fixed_X_fantasies is not None :
399416 if (d_f := fixed_X_fantasies .shape [- 1 ]) != (d_r := X_rnd .shape [- 1 ]):
400417 raise BotorchTensorDimensionError (
@@ -411,6 +428,9 @@ def gen_batch_initial_conditions(
411428 ],
412429 dim = - 2 ,
413430 )
431+
432+ # Evaluate the acquisition function on `X_rnd` using `batch_limit`
433+ # sized chunks.
414434 with torch .no_grad ():
415435 if batch_limit is None :
416436 batch_limit = X_rnd .shape [0 ]
@@ -423,16 +443,22 @@ def gen_batch_initial_conditions(
423443 ],
424444 dim = 0 ,
425445 )
446+
447+ # Downselect the initial conditions based on the acquisition function values
426448 batch_initial_conditions , _ = init_func (
427449 X = X_rnd , acq_vals = acq_vals , n = num_restarts , ** init_kwargs
428450 )
429451 batch_initial_conditions = batch_initial_conditions .to (device = device )
452+
453+ # Return the initial conditions if no warnings were raised
430454 if not any (issubclass (w .category , BadInitialCandidatesWarning ) for w in ws ):
431455 return batch_initial_conditions
456+
432457 if factor < max_factor :
433458 factor += 1
434459 if seed is not None :
435460 seed += 1 # make sure to sample different X_rnd
461+
436462 warnings .warn (
437463 "Unable to find non-zero acquisition function values - initial conditions "
438464 "are being selected randomly." ,
@@ -1057,6 +1083,56 @@ def initialize_q_batch_nonneg(
10571083 return X [idcs ], acq_vals [idcs ]
10581084
10591085
def initialize_q_batch_topn(
    X: Tensor, acq_vals: Tensor, n: int, largest: bool = True, sorted: bool = True
) -> tuple[Tensor, Tensor]:
    r"""Take the top `n` initial conditions for candidate generation.

    Args:
        X: A `b x q x d` tensor of `b` samples of `q`-batches from a `d`-dim.
            feature space. Typically, these are generated using qMC.
        acq_vals: A tensor of `b` outcomes associated with the samples. Typically, this
            is the value of the batch acquisition function to be maximized.
        n: The number of initial conditions to be generated. Must not be larger
            than `b`.
        largest: If True, take the `n` largest acquisition values; if False, take
            the `n` smallest (matches the semantics of `torch.topk`).
        sorted: If True, the returned conditions are sorted by acquisition value
            (descending if `largest=True`, else ascending). NOTE: this parameter
            intentionally mirrors `torch.topk`'s keyword and therefore shadows
            the `sorted` builtin; keep the name for interface compatibility.

    Returns:
        - An `n x q x d` tensor of `n` `q`-batch initial conditions.
        - An `n` tensor of the corresponding acquisition values.

    Example:
        >>> # To get `n=10` starting points of q-batch size `q=3`
        >>> # for model with `d=6`:
        >>> qUCB = qUpperConfidenceBound(model, beta=0.1)
        >>> X_rnd = torch.rand(500, 3, 6)
        >>> X_init, acq_init = initialize_q_batch_topn(
        ...     X=X_rnd, acq_vals=qUCB(X_rnd), n=10
        ... )

    """
    n_samples = X.shape[0]
    if n > n_samples:
        raise RuntimeError(
            f"n ({n}) cannot be larger than the number of "
            f"provided samples ({n_samples})"
        )
    elif n == n_samples:
        # Nothing to downselect: return the inputs unchanged (no reordering).
        return X, acq_vals

    # If the acquisition values are constant, top-n selection is meaningless;
    # fall back to a uniformly random subset and warn the caller.
    Ystd = acq_vals.std(dim=0)
    if torch.any(Ystd == 0):
        warnings.warn(
            "All acquisition values for raw sample points are the same for "
            "at least one batch. Choosing initial conditions at random.",
            BadInitialCandidatesWarning,
            stacklevel=3,
        )
        idcs = torch.randperm(n=n_samples, device=X.device)[:n]
        return X[idcs], acq_vals[idcs]

    topk_out, topk_idcs = acq_vals.topk(n, largest=largest, sorted=sorted)
    return X[topk_idcs], topk_out
1134+
1135+
10601136def sample_points_around_best (
10611137 acq_function : AcquisitionFunction ,
10621138 n_discrete_points : int ,
0 commit comments