66
77import torch
88from botorch .acquisition .analytic import AcquisitionFunction
9- from botorch .acquisition .objective import PosteriorTransform
9+ from botorch .acquisition .objective import (
10+ IdentityMCObjective ,
11+ MCAcquisitionObjective ,
12+ PosteriorTransform ,
13+ )
14+ from botorch .exceptions .errors import UnsupportedError
15+ from botorch .models .deterministic import GenericDeterministicModel
1016from botorch .models .model import Model
1117from botorch .sampling .pathwise .posterior_samplers import get_matheron_path_model
12- from botorch .utils .transforms import t_batch_mode_transform
18+ from botorch .utils .transforms import is_ensemble , t_batch_mode_transform
1319from torch import Tensor
1420
1521
@@ -32,55 +38,151 @@ class PathwiseThompsonSampling(AcquisitionFunction):
def __init__(
    self,
    model: Model,
    objective: MCAcquisitionObjective | None = None,
    posterior_transform: PosteriorTransform | None = None,
) -> None:
    r"""Single-outcome TS.

    With a multi-output `model`, either an `objective` or a scalarizing
    `posterior_transform` must be supplied so that the multi-output posterior
    samples are reduced to single-output samples.

    Args:
        model: A fitted GP model.
        objective: The MCAcquisitionObjective under which the samples are
            evaluated. Defaults to `IdentityMCObjective()`.
        posterior_transform: An optional PosteriorTransform.

    Raises:
        UnsupportedError: If `model` has more than one output, no `objective`
            is given, and `posterior_transform` is either missing or does not
            scalarize.
    """
    super().__init__(model=model)
    # Sample-path state; populated lazily on the first evaluation or by `redraw`.
    self.batch_size: int | None = None
    self.samples: GenericDeterministicModel | None = None
    self.ensemble_indices: Tensor | None = None

    # NOTE: This validation mirrors the one in MCAcquisitionFunction; we should
    # consider inheriting from it and e.g. getting the X_pending logic as well.
    needs_scalarization = objective is None and model.num_outputs != 1
    if needs_scalarization and posterior_transform is None:
        raise UnsupportedError(
            "Must specify an objective or a posterior transform when using "
            "a multi-output model."
        )
    if needs_scalarization and not posterior_transform.scalarize:
        raise UnsupportedError(
            "If using a multi-output model without an objective, "
            "posterior_transform must scalarize the output."
        )

    self.objective = IdentityMCObjective() if objective is None else objective
    self.posterior_transform = posterior_transform
79+
def redraw(self, batch_size: int) -> None:
    """Draw a fresh set of `batch_size` posterior sample paths.

    Replaces `self.samples` with a new path model obtained from
    `get_matheron_path_model` and records `batch_size` on the instance, so
    that a later evaluation with the same q-batch size is accepted. For
    ensemble models, a fresh set of ensemble indices is drawn and cached too.

    Args:
        batch_size: Number of sample paths to draw (used as the sample shape).

    Raises:
        NotImplementedError: If the model is an ensemble whose batch shape has
            more than one dimension.
    """
    sample_shape = torch.Size((batch_size,))

    num_ensemble = None
    if is_ensemble(self.model):
        # the ensembling dimension is assumed to be part of the batch shape
        model_batch_shape = self.model.batch_shape
        # Validate before touching any state so a failed redraw leaves the
        # previously drawn paths intact.
        if len(model_batch_shape) > 1:
            raise NotImplementedError(
                "Ensemble models with more than one ensemble dimension are not "
                "yet supported."
            )
        num_ensemble = model_batch_shape[0]

    self.samples = get_matheron_path_model(
        model=self.model, sample_shape=sample_shape
    )
    # Keep the cached batch size in sync with the drawn paths; otherwise calling
    # `redraw` directly would make the next evaluation fail its batch-size check.
    self.batch_size = batch_size

    if num_ensemble is not None:
        # ensemble_indices is cached here to ensure that the acquisition function
        # becomes deterministic for the same input and can be optimized with LBFGS.
        # ensemble_indices is used in select_from_ensemble_models.
        self.ensemble_indices = torch.randint(
            0,
            num_ensemble,
            (*sample_shape, 1, self.model.num_outputs),
        )
57102
@t_batch_mode_transform()
def forward(self, X: Tensor) -> Tensor:
    r"""Evaluate the pathwise posterior sample draws on the candidate set X.

    Args:
        X: A `batch_shape x q x d`-dim batched tensor of `d`-dim design points.

    Returns:
        A `batch_shape`-dim tensor of evaluations on the posterior sample draws,
        where the samples are summed over the q-batch dimension.
    """
    # batch_shape x q
    per_point_values = self._pathwise_forward(X)
    # NOTE: Summing over the q-batch dimension means we optimize the sum of
    # independent Thompson samples. In the future, *batched* L-BFGS optimization
    # could replace the sum, guaranteeing descent steps for every member of the
    # batch through batch-member-specific learning rate selection.
    return per_point_values.sum(dim=-1)  # batch_shape
71121
def _pathwise_forward(self, X: Tensor) -> Tensor:
    """Evaluate the pathwise posterior sample draws on the candidate set X.

    Sample paths are drawn lazily on the first call, with one path per q-batch
    element. Every later call must use the same q-batch size.

    Args:
        X: A `batch_shape x q x d`-dim batched tensor of `d`-dim design points.

    Returns:
        A `batch_shape x q`-dim tensor of evaluations on the posterior sample draws.

    Raises:
        ValueError: If the q-batch size of `X` differs from the batch size the
            sample paths were drawn for.
    """
    batch_size = X.shape[-2]
    if self.samples is None:
        self.batch_size = batch_size
        self.redraw(batch_size=batch_size)

    if self.batch_size != batch_size:
        raise ValueError(
            BATCH_SIZE_CHANGE_ERROR.format(self.batch_size, batch_size)
        )

    # batch_shape x q x 1 x d
    X = X.unsqueeze(-2)
    # batch_shape x q [x num_ensembles] x 1 x m
    posterior_values = self.samples(X)
    # batch_shape x q [x num_ensembles] x m
    posterior_values = posterior_values.squeeze(-2)

    # batch_shape x q x m
    posterior_values = self.select_from_ensemble_models(values=posterior_values)

    # Compare against None explicitly: a transform object that happens to be
    # falsy (e.g. defines `__len__`) must still be applied.
    if self.posterior_transform is not None:
        posterior_values = self.posterior_transform.evaluate(posterior_values)
    # objective removes the `m` dimension
    return self.objective(posterior_values)  # batch_shape x q
155+
def select_from_ensemble_models(self, values: Tensor) -> Tensor:
    """Select, for each sample path, the value of a single ensemble member.

    For non-ensemble models (as determined by `is_ensemble(self.model)`),
    `values` is returned unchanged. Otherwise `self.ensemble_indices` picks one
    ensemble member per sample path and output. The indices are drawn uniformly
    at random in `redraw` and cached, which keeps the acquisition function
    deterministic for the same input so it can be optimized with LBFGS.

    Args:
        values: A `batch_shape x q [x num_ensemble] x m`-dim Tensor.

    Returns:
        A `batch_shape x q x m`-dim Tensor in which each element is a single
        sample from the ensemble, selected with `self.ensemble_indices`.

    Raises:
        RuntimeError: If the model is an ensemble but no ensemble indices have
            been drawn yet.
    """
    if not is_ensemble(self.model):
        return values

    if self.ensemble_indices is None:
        # Fail with an actionable message instead of an AttributeError on None.
        raise RuntimeError(
            "Ensemble indices have not been drawn; call `redraw` before "
            "selecting from ensemble models."
        )

    ensemble_dim = -2
    # ensemble indices have shape num_paths x 1 x m; move them to the device of
    # `values` once and keep the moved copy cached.
    self.ensemble_indices = self.ensemble_indices.to(device=values.device)
    index = self.ensemble_indices
    input_batch_shape = values.shape[:-3]
    # Prepend the input batch shape so the index matches `values` in every
    # dimension except the ensemble one.
    index = index.expand(*input_batch_shape, *index.shape)
    # values is batch_shape x q x num_ensemble x m
    values_wo_ensemble = torch.gather(values, dim=ensemble_dim, index=index)
    # removing the (now singleton) ensemble dimension
    return values_wo_ensemble.squeeze(ensemble_dim)
0 commit comments