     download_model,
     ModelPaths,
 )
-from botorch_community.posteriors.riemann import BoundedRiemannPosterior
+from botorch_community.posteriors.riemann import (
+    BoundedRiemannPosterior,
+    MultivariateRiemannPosterior,
+)
 from gpytorch.likelihoods.gaussian_likelihood import FixedNoiseGaussianLikelihood
 from pfns.train import MainConfig  # @manual=//pytorch/PFNs:PFNs
 from torch import Tensor
@@ -144,7 +147,7 @@ def posterior(
         Args:
             X: A `b? x q? x d`-dim Tensor, where `d` is the dimension of the
                 feature space.
-            output_indices: **Currenlty not supported for PFNModel.**
+            output_indices: **Currently not supported for PFNModel.**
             observation_noise: **Currently not supported for PFNModel**.
             posterior_transform: **Currently not supported for PFNModel**.
@@ -165,10 +168,7 @@ def posterior(
         if posterior_transform is not None:
             raise UnsupportedError("posterior_transform is not supported for PFNModel.")
 
-        orig_X_shape = X.shape  # X has shape b? x q? x d
-        X = self.prepare_X(X)  # shape (b, q, d)
-        train_X = match_batch_shape(self.transformed_X, X)  # shape (b, n, d)
-        train_Y = match_batch_shape(self.train_Y, X)  # shape (b, n, 1)
+        X, train_X, train_Y, orig_X_shape = self._prepare_data(X)
 
         probabilities = self.pfn_predict(
             X=X, train_X=train_X, train_Y=train_Y
@@ -177,27 +177,34 @@ def posterior(
             *orig_X_shape[:-1], -1
         )  # (b?, q?, num_buckets)
 
-        # Get posterior with the right dtype
-        borders = self.pfn.criterion.borders.to(X.dtype)
         return BoundedRiemannPosterior(
-            borders=borders,
+            borders=self.borders,
             probabilities=probabilities,
         )
 
-    def prepare_X(self, X: Tensor) -> Tensor:
+    def _prepare_data(self, X: Tensor) -> tuple[Tensor, Tensor, Tensor, torch.Size]:
+        orig_X_shape = X.shape  # X has shape b? x q? x d
         if len(X.shape) > 3:
             raise UnsupportedError(f"X must be at most 3-d, got {X.shape}.")
         while len(X.shape) < 3:
             X = X.unsqueeze(0)
 
         X = self.transform_inputs(X)  # shape (b, q, d)
-        return X
+
+        train_X = match_batch_shape(self.transformed_X, X)  # shape (b, n, d)
+        train_Y = match_batch_shape(self.train_Y, X)  # shape (b, n, 1)
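+        # Illustrative shape trace (comment added for exposition, not in the
+        # original change): a 2 x 4 x d input X yields X: (2, 4, d),
+        # train_X: (2, n, d), train_Y: (2, n, 1), and
+        # orig_X_shape == torch.Size([2, 4, d]).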
+        return X, train_X, train_Y, orig_X_shape
 
     def pfn_predict(self, X: Tensor, train_X: Tensor, train_Y: Tensor) -> Tensor:
         """
-        X has shape (b, q, d)
-        train_X has shape (b, n, d)
-        train_Y has shape (b, n, 1)
+        Make a prediction using the PFN model on X given training data.
+
+        Args:
+            X: has shape (b, q, d)
+            train_X: has shape (b, n, d)
+            train_Y: has shape (b, n, 1)
+
+        Returns: probabilities (b, q, num_buckets) for the Riemann posterior.
         """
         if not self.batch_first:
             X = X.transpose(0, 1)  # shape (q, b, d)
@@ -216,3 +223,229 @@ def pfn_predict(self, X: Tensor, train_X: Tensor, train_Y: Tensor) -> Tensor:
 
         probabilities = logits.softmax(dim=-1)  # shape (b, q, num_buckets)
         return probabilities
+
+    @property
+    def borders(self):
+        return self.pfn.criterion.borders.to(self.train_X.dtype)
+
+
+class MultivariatePFNModel(PFNModel):
+    """A multivariate PFN model that returns a joint posterior over q batch inputs.
+
+    For this to work correctly it is necessary that the underlying model return a
+    posterior for the latent f, not the noisy observed y.
+    """
+
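+    # Usage sketch (illustrative only; assumes the returned posterior exposes
+    # its constructor arguments as attributes):
+    #     post = model.posterior(X)          # X: 4 x 3 x d, so q = 3 > 1
+    #     post.probabilities.shape           # (4, 3, num_buckets) marginals
+    #     post.correlation_matrix.shape      # (4, 3, 3) estimated correlations
+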
+    def posterior(
+        self,
+        X: Tensor,
+        output_indices: Optional[list[int]] = None,
+        observation_noise: Union[bool, Tensor] = False,
+        posterior_transform: Optional[PosteriorTransform] = None,
+    ) -> Union[BoundedRiemannPosterior, MultivariateRiemannPosterior]:
+        """Computes the posterior over model outputs at the provided points.
+
+        Will produce a MultivariateRiemannPosterior that fits a joint structure
+        over the q batch dimension of X. This will require an additional forward
+        pass through the PFN model, and some approximation.
+
+        If q = 1 or there is no q dimension, will return a BoundedRiemannPosterior
+        and behave the same as PFNModel.
+
+        Args:
+            X: A `b? x q? x d`-dim Tensor, where `d` is the dimension of the
+                feature space.
+            output_indices: **Currently not supported for PFNModel.**
+            observation_noise: **Currently not supported for PFNModel**.
+            posterior_transform: **Currently not supported for PFNModel**.
+
+        Returns:
+            A posterior representing a batch of b? x q? distributions.
+        """
+        marginals = super().posterior(
+            X=X,
+            output_indices=output_indices,
+            observation_noise=observation_noise,
+            posterior_transform=posterior_transform,
+        )
+        if len(X.shape) == 1 or X.shape[-2] == 1:
+            # No q dimension, or q=1
+            return marginals
+        X, train_X, train_Y, orig_X_shape = self._prepare_data(X)
+        # Estimate correlation structure, making another forward pass.
+        R = self.estimate_correlations(
+            X=X,
+            train_X=train_X,
+            train_Y=train_Y,
+            marginals=marginals,
+        )  # (b, q, q)
+        R = R.view(*orig_X_shape[:-2], X.shape[-2], X.shape[-2])  # (b?, q, q)
+        return MultivariateRiemannPosterior(
+            borders=self.borders,
+            probabilities=marginals.probabilities,
+            correlation_matrix=R,
+        )
+
+    def estimate_correlations(
+        self,
+        X: Tensor,
+        train_X: Tensor,
+        train_Y: Tensor,
+        marginals: BoundedRiemannPosterior,
+    ) -> Tensor:
+        """
+        Estimate a correlation matrix R across the q batch of points in X.
+        Will do a forward pass through the PFN model with batch size O(q^2).
+
+        For every x_q in [x_1, ..., x_Q]:
+            1. Add x_q to train_X, with y_q the 90th percentile value for f(x_q).
+            2. Evaluate p(f(x_i)) for all points.
+
+        Uses bivariate normal conditioning formulae, and so will be approximate.
+
+        Args:
+            X: Evaluation points, shape (b, q, d)
+            train_X: Training X, shape (b, n, d)
+            train_Y: Training Y, shape (b, n, 1)
+            marginals: A posterior object with marginal posteriors for f(X), but no
+                correlation structure yet added. posterior.probabilities has
+                shape (b?, q, num_buckets).
+
+        Returns: A (b, q, q) correlation matrix.
+        """
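+        # Worked illustration of the conditioning identity (comment added for
+        # exposition): with var[f_i] = 1, noise_var_i = 0.25, and
+        # cov[f_j, f_i] = 0.5, conditioning on y_i one unit above E[f_i] shifts
+        # E[f_j] by 0.5 / 1.25 = 0.4, and the estimator below recovers
+        # cov[f_j, f_i] = 0.4 * 1.25 = 0.5.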
+        # Compute conditional distributions with a forward pass
+        cond_mean, cond_val = self._compute_conditional_means(
+            X=X,
+            train_X=train_X,
+            train_Y=train_Y,
+            marginals=marginals,
+        )
+        # Get marginal moments
+        var = marginals.variance.squeeze(-1)  # (b?, q)
+        mean = marginals.mean.squeeze(-1)  # (b?, q)
+        if len(var.shape) == 1:
+            var = var.unsqueeze(0)  # (b, q)
+            mean = mean.unsqueeze(0)  # (b, q)
+        # Estimate covariances from conditional distributions
+        cov = self._estimate_covariances(
+            cond_mean=cond_mean,
+            cond_val=cond_val,
+            mean=mean,
+            var=var,
+        )
+        # Convert to correlation matrix
+        S = 1 / torch.sqrt(torch.diagonal(cov, dim1=-2, dim2=-1))  # (b, q)
+        S = S.unsqueeze(-1).expand(cov.shape)  # (b, q, q)
+        R = S * cov * S.transpose(-1, -2)  # (b, q, q)
+        return R
+
+    def _compute_conditional_means(
+        self,
+        X: Tensor,
+        train_X: Tensor,
+        train_Y: Tensor,
+        marginals: BoundedRiemannPosterior,
+    ) -> tuple[Tensor, Tensor]:
+        """
+        Compute conditional means between pairs of points in X.
+
+        Conditioning is done with an additional forward pass through the model. The
+        returned conditional mean will be of shape (b, q, q), with entry [b, i, j] the
+        conditional mean of j given i set to the conditioning value.
+
+        Args:
+            X: Evaluation points, shape (b, q, d)
+            train_X: Training X, shape (b, n, d)
+            train_Y: Training Y, shape (b, n, 1)
+            marginals: A posterior object with marginal posteriors for f(X), but no
+                correlation structure yet added. posterior.probabilities has
+                shape (b?, q, num_buckets).
+
+        Returns: conditional means (b, q, q), and values used for conditioning (b, q).
+        """
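+        # Shape trace (comment added for exposition): with b=2, q=3, n=10, d=5,
+        # the forward pass below sees eval points of shape (6, 3, 5) and
+        # training data of shapes (6, 11, 5) / (6, 11, 1), returning
+        # probabilities of shape (6, 3, num_buckets).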
+        b, q, d = X.shape
+        n = train_X.shape[-2]
+        post_shape = marginals.probabilities.shape[:-1]
+        # Find the 90th percentile of each eval point.
+        cond_val = marginals.icdf(
+            torch.full(post_shape, 0.9, device=X.device, dtype=X.dtype).unsqueeze(0)
+        )  # (1, b?, q, 1)
+        cond_val = cond_val.view(b, q)  # (b, q)
+        # Construct conditional training data.
+        # train_X will have shape (b, q, n+1, d), to have a conditional observation
+        # for each point. train_Y will have shape (b, q, n+1, 1).
+        train_X = train_X.unsqueeze(1).expand(b, q, n, d)
+        cond_X = X.unsqueeze(-2)  # (b, q, 1, d)
+        train_X = torch.cat((train_X, cond_X), dim=-2)  # (b, q, n+1, d)
+        train_Y = train_Y.unsqueeze(1).expand(b, q, n, 1)
+        cond_Y = cond_val.unsqueeze(-1).unsqueeze(-1)  # (b, q, 1, 1)
+        train_Y = torch.cat((train_Y, cond_Y), dim=-2)  # (b, q, n+1, 1)
+        # Construct eval points
+        eval_X = X.unsqueeze(1).expand(b, q, q, d)
+        # Squeeze everything into the necessary 2 batch dims, and do PFN forward pass
+        cond_probabilities = self.pfn_predict(
+            X=eval_X.reshape(b * q, q, d),
+            train_X=train_X.reshape(b * q, n + 1, d),
+            train_Y=train_Y.reshape(b * q, n + 1, 1),
+        )  # (b * q, q, num_buckets)
+        # Object for conditional posteriors
+        cond_posterior = BoundedRiemannPosterior(
+            borders=self.borders,
+            probabilities=cond_probabilities,
+        )
+        # Get conditional means
+        cond_mean = cond_posterior.mean.squeeze(-1)  # (b * q, q)
+        cond_mean = cond_mean.unsqueeze(0).view(b, q, q)
+        return cond_mean, cond_val
+
+    def _estimate_covariances(
+        self,
+        cond_mean: Tensor,
+        cond_val: Tensor,
+        mean: Tensor,
+        var: Tensor,
+    ) -> Tensor:
+        """
+        Estimate covariances from conditional distributions.
+
+        Part one: Compute the noise variance implied by the conditional distributions.
+            E[f_j | y_j=y] = E[f_j] + var[f_j]/(var[f_j] + noise_var_j) * (y - E[f_j])
+            Let Z_jj = (E[f_j | y_j=y] - E[f_j]) / (y - E[f_j]).
+            Note that Z_jj is in (0, 1].
+            Then, noise_var_j = var[f_j] * (1/Z_jj - 1).
+
+        Part two: Compute covariances for all pairs.
+            E[f_j | y_i=y] = E[f_j] + cov[f_j, f_i]/(var[f_i] + noise_var_i) * (y - E[f_i])
+            Let Z_ij = (E[f_j | y_i=y] - E[f_j]) / (y - E[f_i]).
+            Then, cov[f_j, f_i] = Z_ij * (var[f_i] + noise_var_i).
+
+        Args:
+            cond_mean: (b, q, q) means of dim -1 conditioned on dim -2
+            cond_val: (b, q) conditioned y values.
+            mean: (b, q) marginal means
+            var: (b, q) marginal variances
+
+        Returns: Covariance matrix of shape (b, q, q).
+        """
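+        # Numeric check of Part one (comment added for exposition): if
+        # var[f_j] = 1 and the conditional mean moves 80% of the way to the
+        # conditioning value, then Z_jj = 0.8 and
+        # noise_var_j = 1 * (1/0.8 - 1) = 0.25.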
+        Z = (cond_mean - mean.unsqueeze(-2).expand(cond_mean.shape)) / (
+            cond_val - mean
+        ).unsqueeze(-1)  # (b, q, q)
+        # Z[i, j] is for j conditioned on i
+        noise_var = torch.clamp(
+            var * (1 / torch.diagonal(Z, dim1=-2, dim2=-1) - 1), min=1e-8
+        )  # (b, q)
+        cov = Z * (var + noise_var).unsqueeze(-1)  # (b, q, q)
+        # Symmetrize
+        cov = 0.5 * (cov + cov.transpose(-1, -2))
+        cov = self._map_psd(cov)
+        return cov
+
+    def _map_psd(self, A):
+        """
+        Map A (assumed symmetric) to the nearest PSD matrix.
+        """
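+        # E.g. (illustrative): A = [[1., 2.], [2., 1.]] has eigenvalues 3 and
+        # -1; clamping the negative eigenvalue maps it to approximately
+        # [[1.5, 1.5], [1.5, 1.5]].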
+        if torch.linalg.eigvals(A).real.min() < 0:
+            L, Q = torch.linalg.eigh(A)
+            L = torch.clamp(L, min=1e-6)
+            A = Q @ torch.diag_embed(L) @ Q.transpose(-1, -2)
+        return A