Skip to content

Commit 4190f74

Browse files
CompRhys authored and facebook-github-bot committed
Add ability to mix batch initial conditions and internal IC generation (#2610)
Summary: ## Motivation See #2609 ### Have you read the [Contributing Guidelines on pull requests](https://github.com/pytorch/botorch/blob/main/CONTRIBUTING.md#pull-requests)? Yes Pull Request resolved: #2610 Test Plan: Basic testing of the code is easy; the challenge is working out what the knock-on implications might be — will this break people's code? ## Related PRs facebook/Ax#2938 Reviewed By: Balandat Differential Revision: D66102868 Pulled By: saitcakmak fbshipit-source-id: b3491581a205b0fbe62edd670510e95f13e08177
1 parent a1763a1 commit 4190f74

File tree

3 files changed

+296
-60
lines changed

3 files changed

+296
-60
lines changed

botorch/optim/optimize.py

Lines changed: 165 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -109,12 +109,43 @@ def __post_init__(self) -> None:
109109
"3-dimensional. Its shape is "
110110
f"{batch_initial_conditions_shape}."
111111
)
112+
112113
if batch_initial_conditions_shape[-1] != d:
113114
raise ValueError(
114115
f"batch_initial_conditions.shape[-1] must be {d}. The "
115116
f"shape is {batch_initial_conditions_shape}."
116117
)
117118

119+
if len(batch_initial_conditions_shape) == 2:
120+
warnings.warn(
121+
"If using a 2-dim `batch_initial_conditions` botorch will "
122+
"default to old behavior of ignoring `num_restarts` and just "
123+
"use the given `batch_initial_conditions` by setting "
124+
"`raw_samples` to None.",
125+
RuntimeWarning,
126+
stacklevel=3,
127+
)
128+
# Use object.__setattr__ to bypass immutability and set a value
129+
object.__setattr__(self, "raw_samples", None)
130+
131+
if (
132+
len(batch_initial_conditions_shape) == 3
133+
and batch_initial_conditions_shape[0] < self.num_restarts
134+
and batch_initial_conditions_shape[-2] != self.q
135+
):
136+
warnings.warn(
137+
"If using a 3-dim `batch_initial_conditions` where the "
138+
"first dimension is less than `num_restarts` and the second "
139+
"dimension is not equal to `q`, botorch will default to "
140+
"old behavior of ignoring `num_restarts` and just use the "
141+
"given `batch_initial_conditions` by setting `raw_samples` "
142+
"to None.",
143+
RuntimeWarning,
144+
stacklevel=3,
145+
)
146+
# Use object.__setattr__ to bypass immutability and set a value
147+
object.__setattr__(self, "raw_samples", None)
148+
118149
elif self.ic_generator is None:
119150
if self.nonlinear_inequality_constraints is not None:
120151
raise RuntimeError(
@@ -126,6 +157,7 @@ def __post_init__(self) -> None:
126157
"Must specify `raw_samples` when "
127158
"`batch_initial_conditions` is None`."
128159
)
160+
129161
if self.fixed_features is not None and any(
130162
(k < 0 for k in self.fixed_features)
131163
):
@@ -253,20 +285,49 @@ def _optimize_acqf_sequential_q(
253285
return candidates, torch.stack(acq_value_list)
254286

255287

288+
def _combine_initial_conditions(
289+
provided_initial_conditions: Tensor | None = None,
290+
generated_initial_conditions: Tensor | None = None,
291+
dim=0,
292+
) -> Tensor:
293+
if (
294+
provided_initial_conditions is not None
295+
and generated_initial_conditions is not None
296+
):
297+
return torch.cat(
298+
[provided_initial_conditions, generated_initial_conditions], dim=dim
299+
)
300+
elif provided_initial_conditions is not None:
301+
return provided_initial_conditions
302+
elif generated_initial_conditions is not None:
303+
return generated_initial_conditions
304+
else:
305+
raise ValueError(
306+
"Either `batch_initial_conditions` or `raw_samples` must be set."
307+
)
308+
309+
256310
def _optimize_acqf_batch(opt_inputs: OptimizeAcqfInputs) -> tuple[Tensor, Tensor]:
257311
options = opt_inputs.options or {}
258312

259-
initial_conditions_provided = opt_inputs.batch_initial_conditions is not None
313+
required_num_restarts = opt_inputs.num_restarts
314+
provided_initial_conditions = opt_inputs.batch_initial_conditions
315+
generated_initial_conditions = None
260316

261-
if initial_conditions_provided:
262-
batch_initial_conditions = opt_inputs.batch_initial_conditions
263-
else:
264-
# pyre-ignore[28]: Unexpected keyword argument `acq_function` to anonymous call.
265-
batch_initial_conditions = opt_inputs.get_ic_generator()(
317+
if (
318+
provided_initial_conditions is not None
319+
and len(provided_initial_conditions.shape) == 3
320+
):
321+
required_num_restarts -= provided_initial_conditions.shape[0]
322+
323+
if opt_inputs.raw_samples is not None and required_num_restarts > 0:
324+
# pyre-ignore[28]: Unexpected keyword argument `acq_function`
325+
# to anonymous call.
326+
generated_initial_conditions = opt_inputs.get_ic_generator()(
266327
acq_function=opt_inputs.acq_function,
267328
bounds=opt_inputs.bounds,
268329
q=opt_inputs.q,
269-
num_restarts=opt_inputs.num_restarts,
330+
num_restarts=required_num_restarts,
270331
raw_samples=opt_inputs.raw_samples,
271332
fixed_features=opt_inputs.fixed_features,
272333
options=options,
@@ -275,6 +336,11 @@ def _optimize_acqf_batch(opt_inputs: OptimizeAcqfInputs) -> tuple[Tensor, Tensor
275336
**opt_inputs.ic_gen_kwargs,
276337
)
277338

339+
batch_initial_conditions = _combine_initial_conditions(
340+
provided_initial_conditions=provided_initial_conditions,
341+
generated_initial_conditions=generated_initial_conditions,
342+
)
343+
278344
batch_limit: int = options.get(
279345
"batch_limit",
280346
(
@@ -344,23 +410,24 @@ def _optimize_batch_candidates() -> tuple[Tensor, Tensor, list[Warning]]:
344410
first_warn_msg = (
345411
"Optimization failed in `gen_candidates_scipy` with the following "
346412
f"warning(s):\n{[w.message for w in ws]}\nBecause you specified "
347-
"`batch_initial_conditions`, optimization will not be retried with "
348-
"new initial conditions and will proceed with the current solution."
349-
" Suggested remediation: Try again with different "
350-
"`batch_initial_conditions`, or don't provide `batch_initial_conditions.`"
351-
if initial_conditions_provided
413+
"`batch_initial_conditions` larger than required `num_restarts`, "
414+
"optimization will not be retried with new initial conditions and "
415+
"will proceed with the current solution. Suggested remediation: "
416+
"Try again with different `batch_initial_conditions`, don't provide "
417+
"`batch_initial_conditions`, or increase `num_restarts`."
418+
if batch_initial_conditions is not None and required_num_restarts <= 0
352419
else "Optimization failed in `gen_candidates_scipy` with the following "
353420
f"warning(s):\n{[w.message for w in ws]}\nTrying again with a new "
354421
"set of initial conditions."
355422
)
356423
warnings.warn(first_warn_msg, RuntimeWarning, stacklevel=2)
357424

358-
if not initial_conditions_provided:
359-
batch_initial_conditions = opt_inputs.get_ic_generator()(
425+
if opt_inputs.raw_samples is not None and required_num_restarts > 0:
426+
generated_initial_conditions = opt_inputs.get_ic_generator()(
360427
acq_function=opt_inputs.acq_function,
361428
bounds=opt_inputs.bounds,
362429
q=opt_inputs.q,
363-
num_restarts=opt_inputs.num_restarts,
430+
num_restarts=required_num_restarts,
364431
raw_samples=opt_inputs.raw_samples,
365432
fixed_features=opt_inputs.fixed_features,
366433
options=options,
@@ -369,6 +436,11 @@ def _optimize_batch_candidates() -> tuple[Tensor, Tensor, list[Warning]]:
369436
**opt_inputs.ic_gen_kwargs,
370437
)
371438

439+
batch_initial_conditions = _combine_initial_conditions(
440+
provided_initial_conditions=provided_initial_conditions,
441+
generated_initial_conditions=generated_initial_conditions,
442+
)
443+
372444
batch_candidates, batch_acq_values, ws = _optimize_batch_candidates()
373445

374446
optimization_warning_raised = any(
@@ -1177,7 +1249,7 @@ def _gen_batch_initial_conditions_local_search(
11771249
inequality_constraints: list[tuple[Tensor, Tensor, float]],
11781250
min_points: int,
11791251
max_tries: int = 100,
1180-
):
1252+
) -> Tensor:
11811253
"""Generate initial conditions for local search."""
11821254
device = discrete_choices[0].device
11831255
dtype = discrete_choices[0].dtype
@@ -1197,6 +1269,58 @@ def _gen_batch_initial_conditions_local_search(
11971269
raise RuntimeError(f"Failed to generate at least {min_points} initial conditions")
11981270

11991271

1272+
def _gen_starting_points_local_search(
    discrete_choices: list[Tensor],
    raw_samples: int,
    batch_initial_conditions: Tensor,
    X_avoid: Tensor,
    inequality_constraints: list[tuple[Tensor, Tensor, float]],
    min_points: int,
    acq_function: AcquisitionFunction,
    max_batch_size: int = 2048,
    max_tries: int = 100,
) -> Tensor:
    """Assemble starting points for discrete local search.

    Filters any user-provided ``batch_initial_conditions`` against ``X_avoid``
    and the inequality constraints, generates additional points via
    ``_gen_batch_initial_conditions_local_search`` if the provided points do
    not cover ``min_points``, keeps the best generated points by acquisition
    value, and returns the combination of the two sets.

    Args:
        discrete_choices: A list of possible discrete choices for each dimension.
        raw_samples: Number of raw candidates to sample when generating points.
        batch_initial_conditions: Optional ``n x 1 x d``-dim tensor of
            user-provided starting points (may be None).
        X_avoid: Points that generated/provided candidates must not coincide with.
        inequality_constraints: List of ``(indices, coefficients, rhs)`` tuples.
        min_points: Minimum total number of starting points required.
        acq_function: The acquisition function used to rank generated points.
        max_batch_size: Maximum batch size for acquisition evaluation.
        max_tries: Maximum number of generation attempts.

    Returns:
        A tensor of starting points combining the (filtered) provided points
        and the top generated points.
    """
    required_min_points = min_points
    provided_X0 = None
    generated_X0 = None

    if batch_initial_conditions is not None:
        # Drop provided points that collide with X_avoid or violate the
        # inequality constraints; restore the `n x 1 x d` shape afterwards.
        provided_X0 = _filter_invalid(
            X=batch_initial_conditions.squeeze(1), X_avoid=X_avoid
        )
        provided_X0 = _filter_infeasible(
            X=provided_X0, inequality_constraints=inequality_constraints
        ).unsqueeze(1)
        # NOTE(review): the pre-filtering count is subtracted here, so points
        # removed by the filters are not replaced by generated ones — confirm
        # this is the intended accounting.
        required_min_points -= batch_initial_conditions.shape[0]

    if required_min_points > 0:
        generated_X0 = _gen_batch_initial_conditions_local_search(
            discrete_choices=discrete_choices,
            raw_samples=raw_samples,
            X_avoid=X_avoid,
            inequality_constraints=inequality_constraints,
            min_points=min_points,
            max_tries=max_tries,
        )

        # Rank the generated candidates by acquisition value and keep the
        # best `min_points` of them.
        with torch.no_grad():
            acqvals_init = _split_batch_eval_acqf(
                acq_function=acq_function,
                X=generated_X0.unsqueeze(1),
                max_batch_size=max_batch_size,
            ).unsqueeze(-1)

        generated_X0 = generated_X0[
            acqvals_init.topk(k=min_points, largest=True, dim=0).indices
        ]

    # `x if x is not None else None` in the original reduces to `x`.
    return _combine_initial_conditions(
        provided_initial_conditions=provided_X0,
        generated_initial_conditions=generated_X0,
    )
1322+
1323+
12001324
def optimize_acqf_discrete_local_search(
12011325
acq_function: AcquisitionFunction,
12021326
discrete_choices: list[Tensor],
@@ -1207,6 +1331,7 @@ def optimize_acqf_discrete_local_search(
12071331
X_avoid: Tensor | None = None,
12081332
batch_initial_conditions: Tensor | None = None,
12091333
max_batch_size: int = 2048,
1334+
max_tries: int = 100,
12101335
unique: bool = True,
12111336
) -> tuple[Tensor, Tensor]:
12121337
r"""Optimize acquisition function over a lattice.
@@ -1238,6 +1363,8 @@ def optimize_acqf_discrete_local_search(
12381363
max_batch_size: The maximum number of choices to evaluate in batch.
12391364
A large limit can cause excessive memory usage if the model has
12401365
a large training set.
1366+
max_tries: Maximum number of iterations to try when generating initial
1367+
conditions.
12411368
unique: If True return unique choices, o/w choices may be repeated
12421369
(only relevant if `q > 1`).
12431370
@@ -1247,6 +1374,16 @@ def optimize_acqf_discrete_local_search(
12471374
- a `q x d`-dim tensor of generated candidates.
12481375
- an associated acquisition value.
12491376
"""
1377+
if batch_initial_conditions is not None:
1378+
if not (
1379+
len(batch_initial_conditions.shape) == 3
1380+
and batch_initial_conditions.shape[-2] == 1
1381+
):
1382+
raise ValueError(
1383+
"batch_initial_conditions must have shape `n x 1 x d` if "
1384+
f"given (received shape {batch_initial_conditions.shape})."
1385+
)
1386+
12501387
candidate_list = []
12511388
base_X_pending = acq_function.X_pending if q > 1 else None
12521389
base_X_avoid = X_avoid
@@ -1259,27 +1396,18 @@ def optimize_acqf_discrete_local_search(
12591396
inequality_constraints = inequality_constraints or []
12601397
for i in range(q):
12611398
# generate some starting points
1262-
if i == 0 and batch_initial_conditions is not None:
1263-
X0 = _filter_invalid(X=batch_initial_conditions.squeeze(1), X_avoid=X_avoid)
1264-
X0 = _filter_infeasible(
1265-
X=X0, inequality_constraints=inequality_constraints
1266-
).unsqueeze(1)
1267-
else:
1268-
X_init = _gen_batch_initial_conditions_local_search(
1269-
discrete_choices=discrete_choices,
1270-
raw_samples=raw_samples,
1271-
X_avoid=X_avoid,
1272-
inequality_constraints=inequality_constraints,
1273-
min_points=num_restarts,
1274-
)
1275-
# pick the best starting points
1276-
with torch.no_grad():
1277-
acqvals_init = _split_batch_eval_acqf(
1278-
acq_function=acq_function,
1279-
X=X_init.unsqueeze(1),
1280-
max_batch_size=max_batch_size,
1281-
).unsqueeze(-1)
1282-
X0 = X_init[acqvals_init.topk(k=num_restarts, largest=True, dim=0).indices]
1399+
X0 = _gen_starting_points_local_search(
1400+
discrete_choices=discrete_choices,
1401+
raw_samples=raw_samples,
1402+
batch_initial_conditions=batch_initial_conditions,
1403+
X_avoid=X_avoid,
1404+
inequality_constraints=inequality_constraints,
1405+
min_points=num_restarts,
1406+
acq_function=acq_function,
1407+
max_batch_size=max_batch_size,
1408+
max_tries=max_tries,
1409+
)
1410+
batch_initial_conditions = None
12831411

12841412
# optimize from the best starting points
12851413
best_xs = torch.zeros(len(X0), dim, device=device, dtype=dtype)

botorch/optim/optimize_homotopy.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,6 @@ def optimize_acqf_homotopy(
157157
"""
158158
shared_optimize_acqf_kwargs = {
159159
"num_restarts": num_restarts,
160-
"raw_samples": raw_samples,
161160
"inequality_constraints": inequality_constraints,
162161
"equality_constraints": equality_constraints,
163162
"nonlinear_inequality_constraints": nonlinear_inequality_constraints,
@@ -178,6 +177,7 @@ def optimize_acqf_homotopy(
178177

179178
for _ in range(q):
180179
candidates = batch_initial_conditions
180+
q_raw_samples = raw_samples
181181
homotopy.restart()
182182

183183
while not homotopy.should_stop:
@@ -187,10 +187,15 @@ def optimize_acqf_homotopy(
187187
q=1,
188188
options=options,
189189
batch_initial_conditions=candidates,
190+
raw_samples=q_raw_samples,
190191
**shared_optimize_acqf_kwargs,
191192
)
192193
homotopy.step()
193194

195+
# Set raw_samples to None such that pruned restarts are not repopulated
196+
# at each step in the homotopy.
197+
q_raw_samples = None
198+
194199
# Prune candidates
195200
candidates = prune_candidates(
196201
candidates=candidates.squeeze(1),
@@ -204,6 +209,7 @@ def optimize_acqf_homotopy(
204209
bounds=bounds,
205210
q=1,
206211
options=final_options,
212+
raw_samples=q_raw_samples,
207213
batch_initial_conditions=candidates,
208214
**shared_optimize_acqf_kwargs,
209215
)

0 commit comments

Comments
 (0)