|
23 | 23 |
|
24 | 24 | import torch |
25 | 25 | from botorch import settings |
26 | | -from botorch.acquisition.acquisition import AcquisitionFunction |
| 26 | +from botorch.acquisition.acquisition import AcquisitionFunction, MCSamplerMixin |
27 | 27 | from botorch.exceptions.errors import UnsupportedError |
28 | 28 |
|
29 | 29 | from botorch.models.model import Model |
30 | 30 | from botorch.models.model_list_gp_regression import ModelListGP |
31 | 31 | from botorch.models.utils import fantasize as fantasize_flag |
32 | 32 | from botorch.posteriors.gpytorch import GPyTorchPosterior |
33 | | -from botorch.sampling.samplers import SobolQMCNormalSampler |
| 33 | +from botorch.sampling.normal import SobolQMCNormalSampler |
34 | 34 | from botorch.utils.transforms import concatenate_pending_points, t_batch_mode_transform |
35 | 35 | from torch import Tensor |
36 | 36 | from torch.distributions import Normal |
37 | 37 |
|
38 | 38 |
|
39 | | -class LowerBoundMultiObjectiveEntropySearch(AcquisitionFunction): |
| 39 | +class LowerBoundMultiObjectiveEntropySearch(AcquisitionFunction, MCSamplerMixin): |
40 | 40 | r"""Abstract base class for the lower bound multi-objective entropy search |
41 | 41 | acquisition functions. |
42 | 42 | """ |
@@ -74,6 +74,8 @@ def __init__( |
74 | 74 | estimate. |
75 | 75 | """ |
76 | 76 | super().__init__(model=model) |
| 77 | + sampler = SobolQMCNormalSampler(sample_shape=torch.Size([num_samples])) |
| 78 | + MCSamplerMixin.__init__(self, sampler=sampler) |
77 | 79 | # Batch GP models (e.g. fantasized models) are not currently supported |
78 | 80 | if isinstance(model, ModelListGP): |
79 | 81 | train_X = model.models[0].train_inputs[0] |
@@ -119,9 +121,6 @@ def __init__( |
119 | 121 | + "." |
120 | 122 | ) |
121 | 123 |
|
122 | | - self.sampler = SobolQMCNormalSampler( |
123 | | - num_samples=num_samples, collapse_batch_dims=True |
124 | | - ) |
125 | 124 | self.set_X_pending(X_pending) |
126 | 125 |
|
127 | 126 | @abstractmethod |
@@ -420,7 +419,7 @@ def _compute_monte_carlo_variables( |
420 | 419 | samples. |
421 | 420 | """ |
422 | 421 | # `num_mc_samples x batch_shape x q x num_pareto_samples x 1 x M` |
423 | | - samples = self.sampler(posterior) |
| 422 | + samples = self.get_posterior_samples(posterior) |
424 | 423 |
|
425 | 424 | # `num_mc_samples x batch_shape x q x num_pareto_samples` |
426 | 425 | if self.model.num_outputs == 1: |
|
0 commit comments