66
77import torch
88from botorch .acquisition .analytic import AcquisitionFunction
9- from botorch .acquisition .objective import PosteriorTransform
9+ from botorch .acquisition .objective import (
10+ IdentityMCObjective ,
11+ MCAcquisitionObjective ,
12+ PosteriorTransform ,
13+ )
14+ from botorch .exceptions .errors import UnsupportedError
15+ from botorch .models .deterministic import GenericDeterministicModel
1016from botorch .models .model import Model
1117from botorch .sampling .pathwise .posterior_samplers import get_matheron_path_model
12- from botorch .utils .transforms import t_batch_mode_transform
18+ from botorch .utils .transforms import is_ensemble , t_batch_mode_transform
1319from torch import Tensor
1420
1521
@@ -32,7 +38,9 @@ class PathwiseThompsonSampling(AcquisitionFunction):
3238 def __init__ (
3339 self ,
3440 model : Model ,
41+ objective : MCAcquisitionObjective | None = None ,
3542 posterior_transform : PosteriorTransform | None = None ,
43+ samples : GenericDeterministicModel | None = None ,
3644 ) -> None :
3745 r"""Single-outcome TS.
3846
@@ -41,46 +49,125 @@ def __init__(
4149 posterior_transform: A PosteriorTransform. If using a multi-output model,
4250 a PosteriorTransform that transforms the multi-output posterior into a
4351 single-output posterior is required.
52+ samples: A GenericDeterministicModel that evaluates a set of posterior
53+ sample paths.
4454 """
45- if model ._is_fully_bayesian :
46- raise NotImplementedError (
47- "PathwiseThompsonSampling is not supported for fully Bayesian models" ,
48- )
4955
5056 super ().__init__ (model = model )
51- self .batch_size : int | None = None
52-
53- def redraw (self ) -> None :
57+ self .batch_size : int | None = None if samples is None else samples .batch_shape
58+
59+ # NOTE: This conditional block is copied from MCAcquisitionFunction, we should
60+ # consider inherting from it and e.g. getting the X_pending logic as well.
61+ if objective is None and model .num_outputs != 1 :
62+ if posterior_transform is None :
63+ raise UnsupportedError (
64+ "Must specify an objective or a posterior transform when using "
65+ "a multi-output model."
66+ )
67+ elif not posterior_transform .scalarize :
68+ raise UnsupportedError (
69+ "If using a multi-output model without an objective, "
70+ "posterior_transform must scalarize the output."
71+ )
72+ if objective is None :
73+ objective = IdentityMCObjective ()
74+ self .objective = objective
75+ self .posterior_transform = posterior_transform
76+ self .samples : GenericDeterministicModel | None = samples
77+
78+ def redraw (self , batch_size : int ) -> None :
79+ sample_shape = (batch_size ,)
5480 self .samples = get_matheron_path_model (
55- model = self .model , sample_shape = torch .Size ([ self . batch_size ] )
81+ model = self .model , sample_shape = torch .Size (sample_shape )
5682 )
83+ if is_ensemble (self .model ):
84+ # the ensembling dimension is assumed to be part of the batch shape
85+ # could add a dedicated proporty to keep track of the ensembling dimension
86+ # i.e. generalizing num_mcmc_samples in AbstractFullyBayesianSingleTaskGP
87+ model_batch_shape = self .model .batch_shape
88+ if len (model_batch_shape ) > 1 :
89+ raise NotImplementedError (
90+ "Ensemble models with more than one ensemble dimension are not "
91+ "yet supported."
92+ )
93+ num_ensemble = model_batch_shape [0 ]
94+ self .ensemble_indices = torch .randint (
95+ 0 ,
96+ num_ensemble ,
97+ (* sample_shape , 1 , self .model .num_outputs ),
98+ )
5799
58100 @t_batch_mode_transform ()
59101 def forward (self , X : Tensor ) -> Tensor :
60102 r"""Evaluate the pathwise posterior sample draws on the candidate set X.
61103
62104 Args:
63- X: A `(b1 x ... bk) x 1 x d`-dim batched tensor of `d`-dim design points.
105+ X: A `batch_shape x q x d`-dim batched tensor of `d`-dim design points.
64106
65107 Returns:
66- A `(b1 x ... bk) x [num_models for fully bayesian ]`-dim tensor of
67- evaluations on the posterior sample draws .
108+ A `batch_shape [x m ]`-dim tensor of evaluations on the posterior sample
109+ draws, where `m` is the number of outputs of the model .
68110 """
69- batch_size = X .shape [- 2 ]
70- q_dim = - 2
111+ objective_values = self ._pathwise_forward (X )
112+ # NOTE: can leverage batched L-BFGS computation instead of summing in the future
113+ # sum over batch dim and squeeze num_objectives dim (-1):
114+ acqf_vals = objective_values .sum (- 1 ) # batch_shape
115+ return acqf_vals
71116
117+ def _pathwise_forward (self , X : Tensor ) -> Tensor :
118+ batch_size = X .shape [- 2 ]
72119 # batch_shape x q x 1 x d
73120 X = X .unsqueeze (- 2 )
74- if self .batch_size is None :
121+ if self .samples is None :
75122 self .batch_size = batch_size
76- self .redraw ()
77- elif self .batch_size != batch_size :
123+ self .redraw (batch_size = batch_size )
124+
125+ if self .batch_size != batch_size :
78126 raise ValueError (
79127 BATCH_SIZE_CHANGE_ERROR .format (self .batch_size , batch_size )
80128 )
129+ # batch_shape x q [x num_ensembles] x 1 x m
130+ posterior_values = self .samples (X )
131+ # batch_shape x q [x num_ensembles] x m
132+ posterior_values = posterior_values .squeeze (- 2 )
81133
82- # posterior_values.shape post-squeeze:
83134 # batch_shape x q x m
84- posterior_values = self .samples (X ).squeeze (- 2 )
85- # sum over batch dim and squeeze num_objectives dim (-1)
86- return posterior_values .sum (q_dim ).squeeze (- 1 )
135+ posterior_values = self .select_from_ensemble_models (values = posterior_values )
136+
137+ if self .posterior_transform :
138+ posterior_values = self .posterior_transform .evaluate (posterior_values )
139+ # problem with this currently is that we could still have an `m` dimension,
140+ # ideally that would be packed into a batch dimension instead
141+ # objective removes the `m` dimension:
142+ objective_values = self .objective (posterior_values ) # batch_shape x q
143+ return objective_values
144+
145+ def select_from_ensemble_models (self , values : Tensor ):
146+ """Subselecting a value associated with a single sample in the ensemble for each
147+ element of samples that is not associated with an ensemble dimension. NOTE: uses
148+ `self.model` and `is_ensemble` to determine whether or not an ensembling
149+ dimension is present.
150+
151+ Args:
152+ values: A `batch_shape x num_draws x q [x num_ensemble] x m`-dim Tensor.
153+
154+ Returns:
155+ A`batch_shape x num_draws x q x m`-dim where each element was chosen
156+ independently randomly from the ensemble dimension.
157+ """
158+ if not is_ensemble (self .model ):
159+ return values
160+
161+ ensemble_dim = - 2
162+ # `ensemble_indices` are fixed so that the acquisition function becomes
163+ # deterministic for the same input and can be optimized with LBFGS.
164+ # ensemble indices have shape num_paths x 1 x m
165+ self .ensemble_indices = self .ensemble_indices .to (device = values .device )
166+ index = self .ensemble_indices
167+ input_batch_shape = values .shape [:- 3 ]
168+ index = index .expand (* input_batch_shape , * index .shape )
169+ # samples is batch_shape x q x num_ensemble x m
170+ values_wo_ensemble = torch .gather (values , dim = ensemble_dim , index = index )
171+ return values_wo_ensemble .squeeze (
172+ ensemble_dim
173+ ) # removing the ensemble dimension
0 commit comments