Skip to content

Commit fb459a1

Browse files
committed
Update entropy search sampler
1 parent a5f8f19 commit fb459a1

File tree

4 files changed

+9
-10
lines changed

4 files changed

+9
-10
lines changed

botorch/acquisition/multi_objective/joint_entropy_search.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,20 +23,20 @@
2323

2424
import torch
2525
from botorch import settings
26-
from botorch.acquisition.acquisition import AcquisitionFunction
26+
from botorch.acquisition.acquisition import AcquisitionFunction, MCSamplerMixin
2727
from botorch.exceptions.errors import UnsupportedError
2828

2929
from botorch.models.model import Model
3030
from botorch.models.model_list_gp_regression import ModelListGP
3131
from botorch.models.utils import fantasize as fantasize_flag
3232
from botorch.posteriors.gpytorch import GPyTorchPosterior
33-
from botorch.sampling.samplers import SobolQMCNormalSampler
33+
from botorch.sampling.normal import SobolQMCNormalSampler
3434
from botorch.utils.transforms import concatenate_pending_points, t_batch_mode_transform
3535
from torch import Tensor
3636
from torch.distributions import Normal
3737

3838

39-
class LowerBoundMultiObjectiveEntropySearch(AcquisitionFunction):
39+
class LowerBoundMultiObjectiveEntropySearch(AcquisitionFunction, MCSamplerMixin):
4040
r"""Abstract base class for the lower bound multi-objective entropy search
4141
acquisition functions.
4242
"""
@@ -74,6 +74,8 @@ def __init__(
7474
estimate.
7575
"""
7676
super().__init__(model=model)
77+
sampler = SobolQMCNormalSampler(sample_shape=torch.Size([num_samples]))
78+
MCSamplerMixin.__init__(self, sampler=sampler)
7779
# Batch GP models (e.g. fantasized models) are not currently supported
7880
if isinstance(model, ModelListGP):
7981
train_X = model.models[0].train_inputs[0]
@@ -119,9 +121,6 @@ def __init__(
119121
+ "."
120122
)
121123

122-
self.sampler = SobolQMCNormalSampler(
123-
num_samples=num_samples, collapse_batch_dims=True
124-
)
125124
self.set_X_pending(X_pending)
126125

127126
@abstractmethod
@@ -420,7 +419,7 @@ def _compute_monte_carlo_variables(
420419
samples.
421420
"""
422421
# `num_mc_samples x batch_shape x q x num_pareto_samples x 1 x M`
423-
samples = self.sampler(posterior)
422+
samples = self.get_posterior_samples(posterior)
424423

425424
# `num_mc_samples x batch_shape x q x num_pareto_samples`
426425
if self.model.num_outputs == 1:

botorch/acquisition/multi_objective/max_value_entropy_search.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -338,7 +338,7 @@ def _compute_monte_carlo_variables(
338338
"""
339339

340340
# `num_mc_samples x batch_shape x q x 1 x M`
341-
samples = self.sampler(posterior)
341+
samples = self.get_posterior_samples(posterior)
342342

343343
# `num_mc_samples x batch_shape x q`
344344
if self.model.num_outputs == 1:

test/acquisition/multi_objective/test_joint_entropy_search.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from botorch.models.gp_regression import SingleTaskGP
1818
from botorch.models.model_list_gp_regression import ModelListGP
1919
from botorch.models.transforms.outcome import Standardize
20-
from botorch.sampling.samplers import SobolQMCNormalSampler
20+
from botorch.sampling.normal import SobolQMCNormalSampler
2121
from botorch.utils.testing import BotorchTestCase
2222

2323

test/acquisition/test_joint_entropy_search.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from botorch.models.gp_regression import SingleTaskGP
1313
from botorch.models.model_list_gp_regression import ModelListGP
1414
from botorch.models.transforms.outcome import Standardize
15-
from botorch.sampling.samplers import SobolQMCNormalSampler
15+
from botorch.sampling.normal import SobolQMCNormalSampler
1616
from botorch.utils.testing import BotorchTestCase
1717

1818

0 commit comments

Comments
 (0)