
import torch
from botorch.acquisition.analytic import AcquisitionFunction
-from botorch.acquisition.objective import PosteriorTransform
+from botorch.acquisition.objective import (
+    IdentityMCObjective,
+    MCAcquisitionObjective,
+    PosteriorTransform,
+)
+from botorch.exceptions.errors import UnsupportedError
+from botorch.models.deterministic import GenericDeterministicModel
from botorch.models.model import Model
from botorch.sampling.pathwise.posterior_samplers import get_matheron_path_model
-from botorch.utils.transforms import t_batch_mode_transform
+from botorch.utils.transforms import is_ensemble, t_batch_mode_transform
from torch import Tensor


@@ -32,55 +38,152 @@ class PathwiseThompsonSampling(AcquisitionFunction):
    def __init__(
        self,
        model: Model,
+        objective: MCAcquisitionObjective | None = None,
        posterior_transform: PosteriorTransform | None = None,
    ) -> None:
        r"""Single-outcome TS.

+        If using a multi-output `model`, the acquisition function requires either an
+        `objective` or a `posterior_transform` that transforms the multi-output
+        posterior samples to single-output posterior samples.
+
        Args:
            model: A fitted GP model.
-            posterior_transform: A PosteriorTransform. If using a multi-output model,
-                a PosteriorTransform that transforms the multi-output posterior into a
-                single-output posterior is required.
+            objective: The MCAcquisitionObjective under which the samples are
+                evaluated. Defaults to `IdentityMCObjective()`.
+            posterior_transform: An optional PosteriorTransform.
        """
-        if model._is_fully_bayesian:
-            raise NotImplementedError(
-                "PathwiseThompsonSampling is not supported for fully Bayesian models",
-            )

        super().__init__(model=model)
        self.batch_size: int | None = None
-
-    def redraw(self) -> None:
+        self.samples: GenericDeterministicModel | None = None
+        self.ensemble_indices: Tensor | None = None
+
+        # NOTE: This conditional block is copied from MCAcquisitionFunction; we should
+        # consider inheriting from it and, e.g., getting the X_pending logic as well.
+        if objective is None and model.num_outputs != 1:
+            if posterior_transform is None:
+                raise UnsupportedError(
+                    "Must specify an objective or a posterior transform when using "
+                    "a multi-output model."
+                )
+            elif not posterior_transform.scalarize:
+                raise UnsupportedError(
+                    "If using a multi-output model without an objective, "
+                    "posterior_transform must scalarize the output."
+                )
+        if objective is None:
+            objective = IdentityMCObjective()
+        self.objective = objective
+        self.posterior_transform = posterior_transform
+
+    def redraw(self, batch_size: int) -> None:
+        sample_shape = (batch_size,)
        self.samples = get_matheron_path_model(
-            model=self.model, sample_shape=torch.Size([self.batch_size])
+            model=self.model, sample_shape=torch.Size(sample_shape)
        )
+        if is_ensemble(self.model):
+            # the ensembling dimension is assumed to be part of the batch shape
+            model_batch_shape = self.model.batch_shape
+            if len(model_batch_shape) > 1:
+                raise NotImplementedError(
+                    "Ensemble models with more than one ensemble dimension are not "
+                    "yet supported."
+                )
+            num_ensemble = model_batch_shape[0]
+            # ensemble_indices is cached here to ensure that the acquisition function
+            # becomes deterministic for the same input and can be optimized with LBFGS.
+            # ensemble_indices is used in select_from_ensemble_models.
+            self.ensemble_indices = torch.randint(
+                0,
+                num_ensemble,
+                (*sample_shape, 1, self.model.num_outputs),
+            )

    @t_batch_mode_transform()
    def forward(self, X: Tensor) -> Tensor:
        r"""Evaluate the pathwise posterior sample draws on the candidate set X.

        Args:
-            X: A `(b1 x ... bk) x 1 x d`-dim batched tensor of `d`-dim design points.
+            X: A `batch_shape x q x d`-dim batched tensor of `d`-dim design points.

        Returns:
-            A `(b1 x ... bk) x [num_models for fully bayesian]`-dim tensor of
-                evaluations on the posterior sample draws.
+            A `batch_shape`-dim tensor of evaluations on the posterior sample draws,
+                where the samples are summed over the q-batch dimension.
        """
-        batch_size = X.shape[-2]
-        q_dim = -2
+        objective_values = self._pathwise_forward(X)  # batch_shape x q
+        # NOTE: The current implementation sums over the q-batch dimension, which means
+        # that we are optimizing the sum of independent Thompson samples. In the future,
+        # we can leverage *batched* L-BFGS optimization, rather than summing over the q
+        # dimension, which will guarantee descent steps for all members of the batch
+        # through batch-member-specific learning rate selection.
+        return objective_values.sum(-1)  # batch_shape

+    def _pathwise_forward(self, X: Tensor) -> Tensor:
+        """Evaluate the pathwise posterior sample draws on the candidate set X.
+
+        Args:
+            X: A `batch_shape x q x d`-dim batched tensor of `d`-dim design points.
+
+        Returns:
+            A `batch_shape x q`-dim tensor of evaluations on the posterior sample draws.
+        """
+        batch_size = X.shape[-2]
        # batch_shape x q x 1 x d
        X = X.unsqueeze(-2)
-        if self.batch_size is None:
+        if self.samples is None:
            self.batch_size = batch_size
-            self.redraw()
-        elif self.batch_size != batch_size:
+            self.redraw(batch_size=batch_size)
+
+        if self.batch_size != batch_size:
            raise ValueError(
                BATCH_SIZE_CHANGE_ERROR.format(self.batch_size, batch_size)
            )
+        # batch_shape x q [x num_ensembles] x 1 x m
+        posterior_values = self.samples(X)
+        # batch_shape x q [x num_ensembles] x m
+        posterior_values = posterior_values.squeeze(-2)

-        # posterior_values.shape post-squeeze:
        # batch_shape x q x m
-        posterior_values = self.samples(X).squeeze(-2)
-        # sum over batch dim and squeeze num_objectives dim (-1)
-        return posterior_values.sum(q_dim).squeeze(-1)
+        posterior_values = self.select_from_ensemble_models(values=posterior_values)
+
+        if self.posterior_transform:
+            posterior_values = self.posterior_transform.evaluate(posterior_values)
+        # objective removes the `m` dimension
+        objective_values = self.objective(posterior_values)  # batch_shape x q
+        return objective_values
+
+    def select_from_ensemble_models(self, values: Tensor) -> Tensor:
+        """Subselect, for each element of `values`, the value associated with a single
+        ensemble member, thereby removing the ensemble dimension.
+
+        NOTE: 1) uses `self.model` and `is_ensemble` to determine whether or not an
+        ensembling dimension is present. 2) uses `self.ensemble_indices` to select the
+        value associated with a single member of the ensemble. `ensemble_indices`
+        contains uniformly random indices for each element of the ensemble, but is
+        cached to make the evaluation of the acquisition function deterministic.
+
+        Args:
+            values: A `batch_shape x num_draws x q [x num_ensemble] x m`-dim Tensor.
+
+        Returns:
+            A `batch_shape x num_draws x q x m`-dim tensor where each element contains
+            a single sample from the ensemble, selected with `self.ensemble_indices`.
+        """
+        if not is_ensemble(self.model):
+            return values
+
+        ensemble_dim = -2
+        # `ensemble_indices` are fixed so that the acquisition function becomes
+        # deterministic for the same input and can be optimized with LBFGS.
+        # ensemble indices have shape num_paths x 1 x m
+        self.ensemble_indices = self.ensemble_indices.to(device=values.device)
+        index = self.ensemble_indices
+        input_batch_shape = values.shape[:-3]
+        index = index.expand(*input_batch_shape, *index.shape)
+        # samples is batch_shape x q x num_ensemble x m
+        values_wo_ensemble = torch.gather(values, dim=ensemble_dim, index=index)
+        # remove the (now singleton) ensemble dimension
+        return values_wo_ensemble.squeeze(ensemble_dim)
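
For reference, the gather-based selection in `select_from_ensemble_models` can be reproduced in isolation. The sketch below is only an illustration of the indexing, using made-up shapes (a `(3,)` input batch, `q = 4` paths, 5 ensemble members, `m = 1` output); it is not code from this diff.

```python
import torch

# Illustrative shapes (assumptions, not from the diff).
batch_shape, q, num_ensemble, m = (3,), 4, 5, 1
values = torch.randn(*batch_shape, q, num_ensemble, m)

# Cached once at redraw time: one uniformly random ensemble index per path and output.
ensemble_indices = torch.randint(0, num_ensemble, (q, 1, m))

# Expand the cached indices over the input batch dimensions, gather along the
# ensemble dimension (-2), and drop it, mirroring select_from_ensemble_models.
input_batch_shape = values.shape[:-3]
index = ensemble_indices.expand(*input_batch_shape, *ensemble_indices.shape)
selected = torch.gather(values, dim=-2, index=index).squeeze(-2)
print(selected.shape)  # torch.Size([3, 4, 1])
```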
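
A minimal usage sketch of the new `objective` argument, assuming the import paths below and that a two-output `SingleTaskGP` is supported by the pathwise sampler: the `LinearMCObjective` scalarizes the two outputs, and `forward` returns one value per t-batch since the Thompson samples are summed over the q dimension.

```python
import torch
from botorch.acquisition.objective import LinearMCObjective
from botorch.acquisition.thompson_sampling import PathwiseThompsonSampling
from botorch.models import SingleTaskGP

# Toy two-output model on d=2 inputs.
train_X = torch.rand(16, 2, dtype=torch.double)
train_Y = torch.stack([train_X.sum(-1), (train_X**2).sum(-1)], dim=-1)
model = SingleTaskGP(train_X, train_Y)

# A multi-output model requires an objective (or a scalarizing posterior_transform).
acqf = PathwiseThompsonSampling(
    model=model,
    objective=LinearMCObjective(weights=torch.tensor([1.0, -0.5], dtype=torch.double)),
)

X = torch.rand(5, 2, 2, dtype=torch.double)  # 5 t-batches of q=2 candidate points
values = acqf(X)  # shape: torch.Size([5]); the q dimension is summed out
```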