
Commit c1eb255

CompRhys authored and facebook-github-bot committed
Add TopK downselection for initial batch generation. (#2636)
Summary:

## Motivation

In order to get facebook/Ax#2938 over the line with initial candidate generation that obeys the constraints, we want to use the existing tooling within `botorch`. The hard-coded logic currently in Ax uses top-k to downselect the Sobol samples. To make a change there that will not impact existing users, we first need to implement top-k downselection in `botorch`.

Pull Request resolved: #2636

Test Plan:

- [x] add tests for `initialize_q_batch_topn`

## Related PRs

facebook/Ax#2938. (#2610 was initially intended to play a part in this solution, but then I realized that the pattern I wanted to use was conflating repeats and the batch dimension.)

Reviewed By: Balandat

Differential Revision: D66413947

Pulled By: saitcakmak

fbshipit-source-id: 39e71f5cc0468d554419fa25dd545d9ee25289dc
1 parent 0d5e131 commit c1eb255
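The core idea of the commit, stripped of the surrounding plumbing, is plain top-k downselection of raw samples by acquisition value. A minimal sketch (shapes and values here are illustrative, not taken from the commit):

```python
import torch

# Illustrative stand-ins: 100 raw (q)MC points with q=1, d=6,
# plus their acquisition values.
X_raw = torch.rand(100, 1, 6)
acq_vals = torch.randn(100)

# Top-k downselection: keep the n points with the highest values.
n = 10
top_vals, top_idcs = acq_vals.topk(n, largest=True)
X_init = X_raw[top_idcs]  # shape: n x 1 x 6
```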

File tree

6 files changed: +476 -268 lines changed

botorch/optim/__init__.py

Lines changed: 6 additions & 1 deletion
```diff
@@ -22,7 +22,11 @@
     LinearHomotopySchedule,
     LogLinearHomotopySchedule,
 )
-from botorch.optim.initializers import initialize_q_batch, initialize_q_batch_nonneg
+from botorch.optim.initializers import (
+    initialize_q_batch,
+    initialize_q_batch_nonneg,
+    initialize_q_batch_topn,
+)
 from botorch.optim.optimize import (
     gen_batch_initial_conditions,
     optimize_acqf,
@@ -43,6 +47,7 @@
     "gen_batch_initial_conditions",
     "initialize_q_batch",
     "initialize_q_batch_nonneg",
+    "initialize_q_batch_topn",
     "OptimizationResult",
     "OptimizationStatus",
     "optimize_acqf",
```

botorch/optim/initializers.py

Lines changed: 90 additions & 14 deletions
```diff
@@ -271,13 +271,15 @@ def gen_batch_initial_conditions(
         fixed_features: A map `{feature_index: value}` for features that
             should be fixed to a particular value during generation.
         options: Options for initial condition generation. For valid options see
-            `initialize_q_batch` and `initialize_q_batch_nonneg`. If `options`
-            contains a `nonnegative=True` entry, then `acq_function` is
-            assumed to be non-negative (useful when using custom acquisition
-            functions). In addition, an "init_batch_limit" option can be passed
-            to specify the batch limit for the initialization. This is useful
-            for avoiding memory limits when computing the batch posterior over
-            raw samples.
+            `initialize_q_batch_topn`, `initialize_q_batch_nonneg`, and
+            `initialize_q_batch`. If `options` contains a `topn=True` then
+            `initialize_q_batch_topn` will be used. Else if `options` contains a
+            `nonnegative=True` entry, then `acq_function` is assumed to be
+            non-negative (useful when using custom acquisition functions).
+            `initialize_q_batch` will be used otherwise. In addition, an
+            "init_batch_limit" option can be passed to specify the batch limit
+            for the initialization. This is useful for avoiding memory limits
+            when computing the batch posterior over raw samples.
         inequality constraints: A list of tuples (indices, coefficients, rhs),
             with each tuple encoding an inequality constraint of the form
             `\sum_i (X[indices[i]] * coefficients[i]) >= rhs`.
```
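A minimal sketch of opting into the new path from the caller's side, assuming this commit's options wiring. The toy `acqf` is a stand-in for a real acquisition function, and the remaining argument values are illustrative:

```python
import torch
from botorch.optim import gen_batch_initial_conditions

def acqf(X):  # toy stand-in for a real AcquisitionFunction
    return X.sum(dim=(-2, -1))

bounds = torch.stack([torch.zeros(6), torch.ones(6)])
ics = gen_batch_initial_conditions(
    acq_function=acqf,
    bounds=bounds,
    q=3,
    num_restarts=10,
    raw_samples=512,
    options={"topn": True},  # route downselection through initialize_q_batch_topn
)
assert ics.shape == (10, 3, 6)  # num_restarts x q x d
```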
```diff
@@ -328,14 +330,24 @@ def gen_batch_initial_conditions(
     init_kwargs = {}
     device = bounds.device
     bounds_cpu = bounds.cpu()
-    if "eta" in options:
-        init_kwargs["eta"] = options.get("eta")
-    if options.get("nonnegative") or is_nonnegative(acq_function):
+
+    if options.get("topn"):
+        init_func = initialize_q_batch_topn
+        init_func_opts = ["sorted", "largest"]
+    elif options.get("nonnegative") or is_nonnegative(acq_function):
         init_func = initialize_q_batch_nonneg
-        if "alpha" in options:
-            init_kwargs["alpha"] = options.get("alpha")
+        init_func_opts = ["alpha", "eta"]
     else:
         init_func = initialize_q_batch
+        init_func_opts = ["eta"]
+
+    for opt in init_func_opts:
+        # default value of "largest" to "acq_function.maximize" if it exists
+        if opt == "largest" and hasattr(acq_function, "maximize"):
+            init_kwargs[opt] = acq_function.maximize
+
+        if opt in options:
+            init_kwargs[opt] = options.get(opt)
 
     q = 1 if q is None else q
     # the dimension the samples are drawn from
```
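The loop above forwards only the options that are valid for the chosen initializer. A condensed sketch of that filtering (omitting the `acq_function.maximize` defaulting), assuming the `topn` branch was taken:

```python
options = {"topn": True, "largest": False, "eta": 2.0}
init_func_opts = ["sorted", "largest"]  # valid for initialize_q_batch_topn

init_kwargs = {}
for opt in init_func_opts:
    if opt in options:
        init_kwargs[opt] = options[opt]

# "eta" is dropped: it is not a valid option for the top-n initializer.
assert init_kwargs == {"largest": False}
```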
```diff
@@ -363,7 +375,9 @@ def gen_batch_initial_conditions(
                 X_rnd_nlzd = torch.rand(
                     n, q, bounds_cpu.shape[-1], dtype=bounds.dtype
                 )
-                X_rnd = bounds_cpu[0] + (bounds_cpu[1] - bounds_cpu[0]) * X_rnd_nlzd
+                X_rnd = unnormalize(
+                    X_rnd_nlzd, bounds_cpu, update_constant_bounds=False
+                )
             else:
                 X_rnd = sample_q_batches_from_polytope(
                     n=n,
```
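For non-degenerate bounds the removed arithmetic and the `unnormalize` call agree exactly; a quick equivalence check, relying on the `update_constant_bounds` keyword added in `botorch/utils/transforms.py` below:

```python
import torch
from botorch.utils.transforms import unnormalize

bounds = torch.stack([torch.tensor([0.0, -1.0]), torch.tensor([2.0, 1.0])])
X_nlzd = torch.rand(4, 2)

manual = bounds[0] + (bounds[1] - bounds[0]) * X_nlzd
via_fn = unnormalize(X_nlzd, bounds, update_constant_bounds=False)
assert torch.allclose(manual, via_fn)
```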
```diff
@@ -375,7 +389,8 @@ def gen_batch_initial_conditions(
                     equality_constraints=equality_constraints,
                     inequality_constraints=inequality_constraints,
                 )
-            # sample points around best
+
+            # sample additional points around best
             if sample_around_best:
                 X_best_rnd = sample_points_around_best(
                     acq_function=acq_function,
```
```diff
@@ -395,6 +410,8 @@ def gen_batch_initial_conditions(
                 )
             # Keep X on CPU for consistency & to limit GPU memory usage.
             X_rnd = fix_features(X_rnd, fixed_features=fixed_features).cpu()
+
+            # Append the fixed fantasies to the randomly generated points
             if fixed_X_fantasies is not None:
                 if (d_f := fixed_X_fantasies.shape[-1]) != (d_r := X_rnd.shape[-1]):
                     raise BotorchTensorDimensionError(
```
```diff
@@ -411,6 +428,9 @@ def gen_batch_initial_conditions(
                     ],
                     dim=-2,
                 )
+
+            # Evaluate the acquisition function on `X_rnd` using `batch_limit`
+            # sized chunks.
             with torch.no_grad():
                 if batch_limit is None:
                     batch_limit = X_rnd.shape[0]
```
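The new comment documents the pre-existing chunking pattern; a standalone sketch with a toy acquisition function in place of the real one:

```python
import torch

def acqf(X):  # toy acquisition function: sum over the q and d dimensions
    return X.sum(dim=(-2, -1))

X_rnd = torch.rand(500, 3, 6)
batch_limit = 64
with torch.no_grad():
    # Evaluate chunk by chunk so peak memory scales with batch_limit, not 500.
    acq_vals = torch.cat([acqf(X_) for X_ in X_rnd.split(batch_limit)], dim=0)
assert acq_vals.shape == (500,)
```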
```diff
@@ -423,16 +443,22 @@ def gen_batch_initial_conditions(
                     ],
                     dim=0,
                 )
+
+            # Downselect the initial conditions based on the acquisition function values
             batch_initial_conditions, _ = init_func(
                 X=X_rnd, acq_vals=acq_vals, n=num_restarts, **init_kwargs
             )
             batch_initial_conditions = batch_initial_conditions.to(device=device)
+
+            # Return the initial conditions if no warnings were raised
             if not any(issubclass(w.category, BadInitialCandidatesWarning) for w in ws):
                 return batch_initial_conditions
+
             if factor < max_factor:
                 factor += 1
                 if seed is not None:
                     seed += 1  # make sure to sample different X_rnd
+
     warnings.warn(
         "Unable to find non-zero acquisition function values - initial conditions "
         "are being selected randomly.",
```
```diff
@@ -1057,6 +1083,56 @@ def initialize_q_batch_nonneg(
     return X[idcs], acq_vals[idcs]
 
 
+def initialize_q_batch_topn(
+    X: Tensor, acq_vals: Tensor, n: int, largest: bool = True, sorted: bool = True
+) -> tuple[Tensor, Tensor]:
+    r"""Take the top `n` initial conditions for candidate generation.
+
+    Args:
+        X: A `b x q x d` tensor of `b` samples of `q`-batches from a `d`-dim.
+            feature space. Typically, these are generated using qMC.
+        acq_vals: A tensor of `b` outcomes associated with the samples. Typically, this
+            is the value of the batch acquisition function to be maximized.
+        n: The number of initial conditions to be generated. Must be less than `b`.
+
+    Returns:
+        - An `n x q x d` tensor of `n` `q`-batch initial conditions.
+        - An `n` tensor of the corresponding acquisition values.
+
+    Example:
+        >>> # To get `n=10` starting points of q-batch size `q=3`
+        >>> # for model with `d=6`:
+        >>> qUCB = qUpperConfidenceBound(model, beta=0.1)
+        >>> X_rnd = torch.rand(500, 3, 6)
+        >>> X_init, acq_init = initialize_q_batch_topn(
+        ...     X=X_rnd, acq_vals=qUCB(X_rnd), n=10
+        ... )
+
+    """
+    n_samples = X.shape[0]
+    if n > n_samples:
+        raise RuntimeError(
+            f"n ({n}) cannot be larger than the number of "
+            f"provided samples ({n_samples})"
+        )
+    elif n == n_samples:
+        return X, acq_vals
+
+    Ystd = acq_vals.std(dim=0)
+    if torch.any(Ystd == 0):
+        warnings.warn(
+            "All acquisition values for raw samples points are the same for "
+            "at least one batch. Choosing initial conditions at random.",
+            BadInitialCandidatesWarning,
+            stacklevel=3,
+        )
+        idcs = torch.randperm(n=n_samples, device=X.device)[:n]
+        return X[idcs], acq_vals[idcs]
+
+    topk_out, topk_idcs = acq_vals.topk(n, largest=largest, sorted=sorted)
+    return X[topk_idcs], topk_out
+
+
 def sample_points_around_best(
     acq_function: AcquisitionFunction,
     n_discrete_points: int,
```
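Worth noting beyond the docstring example: `largest=False` selects the lowest-valued batches, and the dispatch in `gen_batch_initial_conditions` defaults `largest` to `acq_function.maximize` when that attribute exists. A small sketch with random stand-in values:

```python
import torch
from botorch.optim.initializers import initialize_q_batch_topn

X_rnd = torch.rand(500, 3, 6)
acq_vals = torch.randn(500)  # stand-in acquisition values

# Keep the 10 lowest-valued q-batches, e.g. when minimizing.
X_init, acq_init = initialize_q_batch_topn(
    X=X_rnd, acq_vals=acq_vals, n=10, largest=False
)
assert torch.all(acq_init <= acq_vals.median())
```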

botorch/utils/feasible_volume.py

Lines changed: 3 additions & 2 deletions
```diff
@@ -11,7 +11,7 @@
 import botorch.models.model as model
 import torch
 from botorch.logging import _get_logger
-from botorch.utils.sampling import manual_seed
+from botorch.utils.sampling import manual_seed, unnormalize
 from torch import Tensor
 
 
@@ -164,9 +164,10 @@ def estimate_feasible_volume(
     seed = seed if seed is not None else torch.randint(0, 1000000, (1,)).item()
 
     with manual_seed(seed=seed):
-        box_samples = bounds[0] + (bounds[1] - bounds[0]) * torch.rand(
+        samples_nlzd = torch.rand(
             (nsample_feature, bounds.size(1)), dtype=dtype, device=device
         )
+        box_samples = unnormalize(samples_nlzd, bounds, update_constant_bounds=False)
 
     features, p_feature = get_feasible_samples(
         samples=box_samples, inequality_constraints=inequality_constraints
```

botorch/utils/sampling.py

Lines changed: 3 additions & 5 deletions
```diff
@@ -98,14 +98,12 @@ def draw_sobol_samples(
     batch_shape = batch_shape or torch.Size()
     batch_size = int(torch.prod(torch.tensor(batch_shape)))
     d = bounds.shape[-1]
-    lower = bounds[0]
-    rng = bounds[1] - bounds[0]
     sobol_engine = SobolEngine(q * d, scramble=True, seed=seed)
-    samples_raw = sobol_engine.draw(batch_size * n, dtype=lower.dtype)
-    samples_raw = samples_raw.view(*batch_shape, n, q, d).to(device=lower.device)
+    samples_raw = sobol_engine.draw(batch_size * n, dtype=bounds.dtype)
+    samples_raw = samples_raw.view(*batch_shape, n, q, d).to(device=bounds.device)
     if batch_shape != torch.Size():
         samples_raw = samples_raw.permute(-3, *range(len(batch_shape)), -2, -1)
-    return lower + rng * samples_raw
+    return unnormalize(samples_raw, bounds, update_constant_bounds=False)
 
 
 def draw_sobol_normal_samples(
```
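The refactor does not change behavior: points are still drawn in the unit cube and mapped into `bounds`. A quick property check:

```python
import torch
from botorch.utils.sampling import draw_sobol_samples

bounds = torch.stack([torch.tensor([0.0, -2.0]), torch.tensor([1.0, 2.0])])
samples = draw_sobol_samples(bounds=bounds, n=16, q=2, seed=0)
assert samples.shape == (16, 2, 2)  # n x q x d
assert torch.all(samples >= bounds[0]) and torch.all(samples <= bounds[1])
```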

botorch/utils/transforms.py

Lines changed: 21 additions & 12 deletions
```diff
@@ -66,17 +66,18 @@ def _update_constant_bounds(bounds: Tensor) -> Tensor:
     return bounds
 
 
-def normalize(X: Tensor, bounds: Tensor) -> Tensor:
+def normalize(X: Tensor, bounds: Tensor, update_constant_bounds: bool = True) -> Tensor:
     r"""Min-max normalize X w.r.t. the provided bounds.
 
-    NOTE: If the upper and lower bounds are identical for a dimension, that dimension
-    will not be scaled. Such dimensions will only be shifted as
-    `new_X[..., i] = X[..., i] - bounds[0, i]`. This avoids division by zero issues.
-
     Args:
         X: `... x d` tensor of data
         bounds: `2 x d` tensor of lower and upper bounds for each of the X's d
             columns.
+        update_constant_bounds: If `True`, update the constant bounds in order to
+            avoid division by zero issues. When the upper and lower bounds are
+            identical for a dimension, that dimension will not be scaled. Such
+            dimensions will only be shifted as
+            `new_X[..., i] = X[..., i] - bounds[0, i]`.
 
     Returns:
         A `... x d`-dim tensor of normalized data, given by
@@ -89,21 +90,27 @@ def normalize(X: Tensor, bounds: Tensor) -> Tensor:
         >>> bounds = torch.stack([torch.zeros(3), 0.5 * torch.ones(3)])
         >>> X_normalized = normalize(X, bounds)
     """
-    bounds = _update_constant_bounds(bounds=bounds)
+    bounds = (
+        _update_constant_bounds(bounds=bounds) if update_constant_bounds else bounds
+    )
     return (X - bounds[0]) / (bounds[1] - bounds[0])
 
 
-def unnormalize(X: Tensor, bounds: Tensor) -> Tensor:
+def unnormalize(
+    X: Tensor, bounds: Tensor, update_constant_bounds: bool = True
+) -> Tensor:
     r"""Un-normalizes X w.r.t. the provided bounds.
 
-    NOTE: If the upper and lower bounds are identical for a dimension, that dimension
-    will not be scaled. Such dimensions will only be shifted as
-    `new_X[..., i] = X[..., i] + bounds[0, i]`, matching the behavior of `normalize`.
-
     Args:
         X: `... x d` tensor of data
         bounds: `2 x d` tensor of lower and upper bounds for each of the X's d
             columns.
+        update_constant_bounds: If `True`, update the constant bounds in order to
+            avoid division by zero issues. When the upper and lower bounds are
+            identical for a dimension, that dimension will not be scaled. Such
+            dimensions will only be shifted as
+            `new_X[..., i] = X[..., i] + bounds[0, i]`. This is the inverse of
+            the behavior of `normalize` when `update_constant_bounds=True`.
 
     Returns:
         A `... x d`-dim tensor of unnormalized data, given by
@@ -116,7 +123,9 @@ def unnormalize(X: Tensor, bounds: Tensor) -> Tensor:
         >>> bounds = torch.stack([torch.zeros(3), 0.5 * torch.ones(3)])
         >>> X = unnormalize(X_normalized, bounds)
     """
-    bounds = _update_constant_bounds(bounds=bounds)
+    bounds = (
+        _update_constant_bounds(bounds=bounds) if update_constant_bounds else bounds
+    )
     return X * (bounds[1] - bounds[0]) + bounds[0]
```
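The flag matters exactly when a dimension has identical lower and upper bounds: the default shifts such dimensions instead of scaling them, while `update_constant_bounds=False` applies the raw formula, which divides by zero in `normalize`. A small demonstration:

```python
import torch
from botorch.utils.transforms import normalize, unnormalize

bounds = torch.tensor([[0.0, 1.0], [2.0, 1.0]])  # second dimension is constant
X = torch.tensor([[1.0, 1.0]])

# Default: the constant dimension is only shifted -> no division by zero.
normalize(X, bounds)                                # tensor([[0.5000, 0.0000]])

# Raw formula: 0/0 on the constant dimension yields NaN.
normalize(X, bounds, update_constant_bounds=False)  # tensor([[0.5000, nan]])

# For non-degenerate bounds, unnormalize remains the exact inverse.
box = torch.stack([torch.zeros(3), 2 * torch.ones(3)])
Z = torch.rand(5, 3)
assert torch.allclose(unnormalize(normalize(Z, box), box), Z)
```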