|
9 | 9 | from botorch.acquisition.objective import PosteriorTransform |
10 | 10 | from botorch.models.model import Model |
11 | 11 | from botorch.sampling.pathwise.posterior_samplers import get_matheron_path_model |
12 | | -from botorch.utils.transforms import t_batch_mode_transform |
| 12 | +from botorch.utils.transforms import is_ensemble, t_batch_mode_transform |
13 | 13 | from torch import Tensor |
14 | 14 |
|
15 | 15 |
|
@@ -42,45 +42,88 @@ def __init__( |
42 | 42 | a PosteriorTransform that transforms the multi-output posterior into a |
43 | 43 | single-output posterior is required. |
44 | 44 | """ |
45 | | - if model._is_fully_bayesian: |
46 | | - raise NotImplementedError( |
47 | | - "PathwiseThompsonSampling is not supported for fully Bayesian models", |
48 | | - ) |
49 | 45 |
|
50 | 46 | super().__init__(model=model) |
51 | 47 | self.batch_size: int | None = None |
52 | 48 |
|
53 | | - def redraw(self) -> None: |
| 49 | + def redraw(self, batch_size: int) -> None: |
| 50 | + sample_shape = (batch_size,) |
54 | 51 | self.samples = get_matheron_path_model( |
55 | | - model=self.model, sample_shape=torch.Size([self.batch_size]) |
| 52 | + model=self.model, sample_shape=torch.Size(sample_shape) |
56 | 53 | ) |
| 54 | + if is_ensemble(self.model): |
| 55 | + # the ensembling dimension is assumed to be part of the batch shape |
| 56 | + # could add a dedicated proporty to keep track of the ensembling dimension |
| 57 | + # i.e. generalizing num_mcmc_samples in AbstractFullyBayesianSingleTaskGP |
| 58 | + model_batch_shape = self.model.batch_shape |
| 59 | + if len(model_batch_shape) > 1: |
| 60 | + raise NotImplementedError( |
| 61 | + "Ensemble models with more than one ensemble dimension are not " |
| 62 | + "yet supported." |
| 63 | + ) |
| 64 | + num_ensemble = model_batch_shape[0] |
| 65 | + self.ensemble_indices = torch.randint( |
| 66 | + 0, num_ensemble, (*sample_shape, 1, self.model.num_outputs) |
| 67 | + ) |
57 | 68 |
|
58 | 69 | @t_batch_mode_transform() |
59 | 70 | def forward(self, X: Tensor) -> Tensor: |
60 | 71 | r"""Evaluate the pathwise posterior sample draws on the candidate set X. |
61 | 72 |
|
62 | 73 | Args: |
63 | | - X: A `(b1 x ... bk) x 1 x d`-dim batched tensor of `d`-dim design points. |
| 74 | + X: A `batch_shape x q x d`-dim batched tensor of `d`-dim design points. |
64 | 75 |
|
65 | 76 | Returns: |
66 | | - A `(b1 x ... bk) x [num_models for fully bayesian]`-dim tensor of |
67 | | - evaluations on the posterior sample draws. |
| 77 | + A `batch_shape [x m]`-dim tensor of evaluations on the posterior sample |
| 78 | + draws, where `m` is the number of outputs of the model. |
68 | 79 | """ |
69 | 80 | batch_size = X.shape[-2] |
70 | 81 | q_dim = -2 |
71 | | - |
72 | 82 | # batch_shape x q x 1 x d |
73 | 83 | X = X.unsqueeze(-2) |
74 | 84 | if self.batch_size is None: |
75 | 85 | self.batch_size = batch_size |
76 | | - self.redraw() |
| 86 | + self.redraw(batch_size=batch_size) |
77 | 87 | elif self.batch_size != batch_size: |
78 | 88 | raise ValueError( |
79 | 89 | BATCH_SIZE_CHANGE_ERROR.format(self.batch_size, batch_size) |
80 | 90 | ) |
81 | | - |
82 | | - # posterior_values.shape post-squeeze: |
| 91 | + # batch_shape x q [x num_ensembles] x 1 x m |
| 92 | + posterior_values = self.samples(X) |
| 93 | + # batch_shape x q [x num_ensembles] x m |
| 94 | + posterior_values = posterior_values.squeeze(-2) |
83 | 95 | # batch_shape x q x m |
84 | | - posterior_values = self.samples(X).squeeze(-2) |
85 | | - # sum over batch dim and squeeze num_objectives dim (-1) |
86 | | - return posterior_values.sum(q_dim).squeeze(-1) |
| 96 | + posterior_values = self.select_from_ensemble_models(values=posterior_values) |
| 97 | + # NOTE: can leverage batched L-BFGS computation instead of summing in the future |
| 98 | + # sum over batch dim and squeeze num_objectives dim (-1): batch_shape [x m] |
| 99 | + acqf_vals = posterior_values.sum(q_dim).squeeze(-1) |
| 100 | + return acqf_vals |
| 101 | + |
| 102 | + def select_from_ensemble_models(self, values: Tensor): |
| 103 | + """Subselecting a value associated with a single sample in the ensemble for each |
| 104 | + element of samples that is not associated with an ensemble dimension. NOTE: uses |
| 105 | + `self.model` and `is_ensemble` to determine whether or not an ensembling |
| 106 | + dimension is present. |
| 107 | +
|
| 108 | + Args: |
| 109 | + values: A `batch_shape x num_draws x q [x num_ensemble] x m`-dim Tensor. |
| 110 | +
|
| 111 | + Returns: |
| 112 | + A`batch_shape x num_draws x q x m`-dim where each element was chosen |
| 113 | + independently randomly from the ensemble dimension. |
| 114 | + """ |
| 115 | + if not is_ensemble(self.model): |
| 116 | + return values |
| 117 | + |
| 118 | + ensemble_dim = -2 |
| 119 | + # `ensemble_indices` are fixed so that the acquisition function becomes |
| 120 | + # deterministic for the same input and can be optimized with LBFGS. |
| 121 | + # ensemble indices have shape num_paths x 1 x m |
| 122 | + index = self.ensemble_indices |
| 123 | + input_batch_shape = values.shape[:-3] |
| 124 | + index = index.expand(*input_batch_shape, *index.shape) |
| 125 | + # samples is batch_shape x q x num_ensemble x m |
| 126 | + values_wo_ensemble = torch.gather(values, dim=ensemble_dim, index=index) |
| 127 | + return values_wo_ensemble.squeeze( |
| 128 | + ensemble_dim |
| 129 | + ) # removing the ensemble dimension |
0 commit comments