# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Callable, List, Optional, Union

import torch
from botorch.acquisition.logei import qLogNoisyExpectedImprovement, TAU_MAX, TAU_RELU
from botorch.acquisition.multi_objective.monte_carlo import (
    MultiObjectiveMCAcquisitionFunction,
)
from botorch.acquisition.multi_objective.objective import MCMultiOutputObjective
from botorch.acquisition.objective import GenericMCObjective
from botorch.models.model import Model
from botorch.posteriors.fully_bayesian import MCMC_DIM
from botorch.sampling.base import MCSampler
from botorch.utils.multi_objective.scalarization import get_chebyshev_scalarization
from botorch.utils.sampling import sample_simplex
from botorch.utils.transforms import is_ensemble
from torch import Tensor


class qLogNParEGO(qLogNoisyExpectedImprovement, MultiObjectiveMCAcquisitionFunction):
    def __init__(
        self,
        model: Model,
        X_baseline: Tensor,
        scalarization_weights: Optional[Tensor] = None,
        sampler: Optional[MCSampler] = None,
        objective: Optional[MCMultiOutputObjective] = None,
        constraints: Optional[List[Callable[[Tensor], Tensor]]] = None,
        X_pending: Optional[Tensor] = None,
        eta: Union[Tensor, float] = 1e-3,
        fat: bool = True,
        prune_baseline: bool = False,
        cache_root: bool = True,
        tau_relu: float = TAU_RELU,
        tau_max: float = TAU_MAX,
    ) -> None:
        r"""q-LogNParEGO supporting m >= 2 outcomes. This acquisition function
        utilizes qLogNEI to compute the expected improvement over the Chebyshev
        scalarization of the objectives.

        This is adapted from the qNParEGO proposed in [Daulton2020qehvi]_ to
        utilize the log-improvement acquisition functions of [Ament2023logei]_.
        See [Knowles2005]_ for the original ParEGO algorithm.

        This implementation assumes maximization of all objectives. If any of the
        model outputs are to be minimized, either an `objective` should be used to
        negate the model outputs, or the `scalarization_weights` should be provided
        with negative weights for the outputs to be minimized.
        Args:
            model: A fitted multi-output model, producing outputs for `m` objectives
                and any number of outcome constraints.
                NOTE: The model posterior must have a `mean` attribute.
            X_baseline: A `batch_shape x r x d`-dim Tensor of `r` design points
                that have already been observed. These points are considered as
                potential best design points.
            scalarization_weights: An `m`-dim Tensor of weights to be used in the
                Chebyshev scalarization. If omitted, samples from the unit simplex.
            sampler: The sampler used to draw base samples. See
                `MCAcquisitionFunction` for more details.
            objective: The `MCMultiOutputObjective` under which the samples are
                evaluated before applying the Chebyshev scalarization.
                Defaults to `IdentityMCMultiOutputObjective()`.
            constraints: A list of constraint callables which map a Tensor of posterior
                samples of dimension `sample_shape x batch-shape x q x m'` to a
                `sample_shape x batch-shape x q`-dim Tensor. The associated constraints
                are satisfied if `constraint(samples) < 0`.
            X_pending: A `batch_shape x q' x d`-dim Tensor of `q'` design points
                that have been submitted for function evaluation but have not yet
                been evaluated. Concatenated into `X` upon forward call. Copied and
                set to have no gradient.
            eta: Temperature parameter(s) governing the smoothness of the sigmoid
                approximation to the constraint indicators. See the docs of
                `compute_(log_)smoothed_constraint_indicator` for details.
            fat: Toggles the logarithmic / linear asymptotic behavior of the smooth
                approximation to the ReLU.
            prune_baseline: If True, remove points in `X_baseline` that are
                highly unlikely to be the best point. This can significantly
                improve performance and is generally recommended. In order to
                customize pruning parameters, instead manually call
                `botorch.acquisition.utils.prune_inferior_points` on `X_baseline`
                before instantiating the acquisition function.
            cache_root: A boolean indicating whether to cache the root
                decomposition over `X_baseline` and use low-rank updates.
            tau_relu: Temperature parameter controlling the sharpness of the smooth
                approximations to ReLU.
            tau_max: Temperature parameter controlling the sharpness of the smooth
                approximations to max.
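
        Example:
            A minimal sketch; the `SingleTaskGP`, `train_X`, `train_Y`, and
            `test_X` names below are illustrative placeholders:

            >>> model = SingleTaskGP(train_X, train_Y)  # train_Y is `n x 2`
            >>> acqf = qLogNParEGO(model=model, X_baseline=train_X)
            >>> value = acqf(test_X)  # test_X is a `q x d` candidate batch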
        """
        MultiObjectiveMCAcquisitionFunction.__init__(
            self,
            model=model,
            sampler=sampler,
            objective=objective,
            constraints=constraints,
            eta=eta,
        )
        org_objective = self.objective
        # Create the composite objective.
        with torch.no_grad():
            Y_baseline = org_objective(model.posterior(X_baseline).mean)
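            # For fully Bayesian (ensemble) models, the posterior mean carries an
            # extra MCMC batch dimension; average it out so downstream shapes
            # match the non-ensemble case.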
            if is_ensemble(model):
                Y_baseline = torch.mean(Y_baseline, dim=MCMC_DIM)
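        # If no weights are given, draw a single weight vector uniformly at
        # random from the `m`-dimensional unit simplex.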
        scalarization_weights = (
            scalarization_weights
            if scalarization_weights is not None
            else sample_simplex(
                d=Y_baseline.shape[-1], device=X_baseline.device, dtype=X_baseline.dtype
            ).view(-1)
        )
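        # Construct the Chebyshev scalarization, normalized based on the
        # baseline outcomes `Y_baseline`.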
        chebyshev_scalarization = get_chebyshev_scalarization(
            weights=scalarization_weights,
            Y=Y_baseline,
        )
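        # Chain the original multi-output objective with the scalarization so
        # that qLogNEI sees a single scalarized outcome per posterior sample.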
        composite_objective = GenericMCObjective(
            objective=lambda samples, X=None: chebyshev_scalarization(
                org_objective(samples=samples, X=X), X=X
            ),
        )
        qLogNoisyExpectedImprovement.__init__(
            self,
            model=model,
            X_baseline=X_baseline,
            sampler=sampler,
            # This overwrites self.objective with the composite objective.
            objective=composite_objective,
            X_pending=X_pending,
            constraints=constraints,
            eta=eta,
            fat=fat,
            prune_baseline=prune_baseline,
            cache_root=cache_root,
            tau_max=tau_max,
            tau_relu=tau_relu,
        )
        # Set these after the __init__ calls so that they're not overwritten /
        # deleted. These are intended mainly for easier debugging & transparency.
        self._org_objective: MCMultiOutputObjective = org_objective
        self.chebyshev_scalarization: Callable[[Tensor, Optional[Tensor]], Tensor] = (
            chebyshev_scalarization
        )
        self.scalarization_weights: Tensor = scalarization_weights
        self.Y_baseline: Tensor = Y_baseline