
Commit 09502f9

bletham authored and meta-codesync[bot] committed
Copula multivariate posterior for PFN (#3045)
Summary:
Pull Request resolved: #3045

Implements a PFN model subclass that produces joint samples over the q-batch.

Reviewed By: SamuelGabriel

Differential Revision: D83996991

fbshipit-source-id: 2ebad3ea396c9dc6406242b5592adf3f631fca0c
1 parent f122efc commit 09502f9

4 files changed: +484 additions, -16 deletions

botorch_community/models/prior_fitted_network.py

Lines changed: 247 additions & 14 deletions
@@ -26,7 +26,10 @@
     download_model,
     ModelPaths,
 )
-from botorch_community.posteriors.riemann import BoundedRiemannPosterior
+from botorch_community.posteriors.riemann import (
+    BoundedRiemannPosterior,
+    MultivariateRiemannPosterior,
+)
 from gpytorch.likelihoods.gaussian_likelihood import FixedNoiseGaussianLikelihood
 from pfns.train import MainConfig  # @manual=//pytorch/PFNs:PFNs
 from torch import Tensor
@@ -144,7 +147,7 @@ def posterior(
         Args:
             X: A `b? x q? x d`-dim Tensor, where `d` is the dimension of the
                 feature space.
-            output_indices: **Currenlty not supported for PFNModel.**
+            output_indices: **Currently not supported for PFNModel.**
             observation_noise: **Currently not supported for PFNModel**.
             posterior_transform: **Currently not supported for PFNModel**.
@@ -165,10 +168,7 @@ def posterior(
         if posterior_transform is not None:
             raise UnsupportedError("posterior_transform is not supported for PFNModel.")

-        orig_X_shape = X.shape  # X has shape b? x q? x d
-        X = self.prepare_X(X)  # shape (b, q, d)
-        train_X = match_batch_shape(self.transformed_X, X)  # shape (b, n, d)
-        train_Y = match_batch_shape(self.train_Y, X)  # shape (b, n, 1)
+        X, train_X, train_Y, orig_X_shape = self._prepare_data(X)

         probabilities = self.pfn_predict(
             X=X, train_X=train_X, train_Y=train_Y
@@ -177,27 +177,34 @@
             *orig_X_shape[:-1], -1
         )  # (b?, q?, num_buckets)

-        # Get posterior with the right dtype
-        borders = self.pfn.criterion.borders.to(X.dtype)
         return BoundedRiemannPosterior(
-            borders=borders,
+            borders=self.borders,
             probabilities=probabilities,
         )

-    def prepare_X(self, X: Tensor) -> Tensor:
+    def _prepare_data(self, X: Tensor) -> tuple[Tensor, Tensor, Tensor, torch.Size]:
+        orig_X_shape = X.shape  # X has shape b? x q? x d
         if len(X.shape) > 3:
             raise UnsupportedError(f"X must be at most 3-d, got {X.shape}.")
         while len(X.shape) < 3:
             X = X.unsqueeze(0)

         X = self.transform_inputs(X)  # shape (b, q, d)
-        return X
+
+        train_X = match_batch_shape(self.transformed_X, X)  # shape (b, n, d)
+        train_Y = match_batch_shape(self.train_Y, X)  # shape (b, n, 1)
+        return X, train_X, train_Y, orig_X_shape

     def pfn_predict(self, X: Tensor, train_X: Tensor, train_Y: Tensor) -> Tensor:
         """
-        X has shape (b, q, d)
-        train_X has shape (b, n, d)
-        train_Y has shape (b, n, 1)
+        Make a prediction using the PFN model on X given training data.
+
+        Args:
+            X: has shape (b, q, d)
+            train_X: has shape (b, n, d)
+            train_Y: has shape (b, n, 1)
+
+        Returns: probabilities (b, q, num_buckets) for Riemann posterior.
         """
         if not self.batch_first:
             X = X.transpose(0, 1)  # shape (q, b, d)
@@ -216,3 +223,229 @@ def pfn_predict(self, X: Tensor, train_X: Tensor, train_Y: Tensor) -> Tensor:

         probabilities = logits.softmax(dim=-1)  # shape (b, q, num_buckets)
         return probabilities
+
+    @property
+    def borders(self):
+        return self.pfn.criterion.borders.to(self.train_X.dtype)
+
+
+class MultivariatePFNModel(PFNModel):
+    """A multivariate PFN model that returns a joint posterior over q batch inputs.
+
+    For this to work correctly it is necessary that the underlying model return a
+    posterior for the latent f, not the noisy observed y.
+    """
+
+    def posterior(
+        self,
+        X: Tensor,
+        output_indices: Optional[list[int]] = None,
+        observation_noise: Union[bool, Tensor] = False,
+        posterior_transform: Optional[PosteriorTransform] = None,
+    ) -> Union[BoundedRiemannPosterior, MultivariateRiemannPosterior]:
+        """Computes the posterior over model outputs at the provided points.
+
+        Will produce a MultivariateRiemannPosterior that fits a joint structure
+        over the q batch dimension of X. This will require an additional forward
+        pass through the PFN model, and some approximation.
+
+        If q = 1 or there is no q dimension, will return a BoundedRiemannPosterior
+        and behave the same as PFNModel.
+
+        Args:
+            X: A `b? x q? x d`-dim Tensor, where `d` is the dimension of the
+                feature space.
+            output_indices: **Currently not supported for PFNModel.**
+            observation_noise: **Currently not supported for PFNModel**.
+            posterior_transform: **Currently not supported for PFNModel**.
+
+        Returns:
+            A posterior representing a batch of b? x q? distributions.
+        """
+        marginals = super().posterior(
+            X=X,
+            output_indices=output_indices,
+            observation_noise=observation_noise,
+            posterior_transform=posterior_transform,
+        )
+        if len(X.shape) == 1 or X.shape[-2] == 1:
+            # No q dimension, or q=1
+            return marginals
+        X, train_X, train_Y, orig_X_shape = self._prepare_data(X)
+        # Estimate correlation structure, making another forward pass.
+        R = self.estimate_correlations(
+            X=X,
+            train_X=train_X,
+            train_Y=train_Y,
+            marginals=marginals,
+        )  # (b, q, q)
+        R = R.view(*orig_X_shape[:-2], X.shape[-2], X.shape[-2])  # (b?, q, q)
+        return MultivariateRiemannPosterior(
+            borders=self.borders,
+            probabilities=marginals.probabilities,
+            correlation_matrix=R,
+        )
+
+    def estimate_correlations(
+        self,
+        X: Tensor,
+        train_X: Tensor,
+        train_Y: Tensor,
+        marginals: BoundedRiemannPosterior,
+    ) -> Tensor:
+        """
+        Estimate a correlation matrix R across the q batch of points in X.
+        Will do a forward pass through the PFN model with batch size O(q^2).
+
+        For every x_q in [x_1, ..., x_Q]:
+        1. Add x_q to train_X, with y_q the 90th percentile value for f(x_q)
+        2. Evaluate p(f(x_i)) for all points.
+
+        Uses bivariate normal conditioning formulae, and so will be approximate.
+
+        Args:
+            X: evaluation point, shape (b, q, d)
+            train_X: Training X, shape (b, n, d)
+            train_Y: Training Y, shape (b, n, 1)
+            marginals: A posterior object with marginal posteriors for f(X), but no
+                correlation structure yet added. posterior.probabilities has
+                shape (b?, q, num_buckets).
+
+        Returns: A (b, q, q) correlation matrix
+        """
+        # Compute conditional distributions with a forward pass
+        cond_mean, cond_val = self._compute_conditional_means(
+            X=X,
+            train_X=train_X,
+            train_Y=train_Y,
+            marginals=marginals,
+        )
+        # Get marginal moments
+        var = marginals.variance.squeeze(-1)  # (b?, q)
+        mean = marginals.mean.squeeze(-1)  # (b?, q)
+        if len(var.shape) == 1:
+            var = var.unsqueeze(0)  # (b, q)
+            mean = mean.unsqueeze(0)  # (b, q)
+        # Estimate covariances from conditional distributions
+        cov = self._estimate_covariances(
+            cond_mean=cond_mean,
+            cond_val=cond_val,
+            mean=mean,
+            var=var,
+        )
+        # Convert to correlation matrix
+        S = 1 / torch.sqrt(torch.diagonal(cov, dim1=-2, dim2=-1))  # (b, q)
+        S = S.unsqueeze(-1).expand(cov.shape)  # (b, q, q)
+        R = S * cov * S.transpose(-1, -2)  # (b, q, q)
+        return R
+
+    def _compute_conditional_means(
+        self,
+        X: Tensor,
+        train_X: Tensor,
+        train_Y: Tensor,
+        marginals: BoundedRiemannPosterior,
+    ) -> tuple[Tensor, Tensor]:
+        """
+        Compute conditional means between pairs of points in X.
+
+        Conditioning is done with an additional forward pass through the model. The
+        returned conditional mean will be of shape (b, q, q), with entry [b, i, j] the
+        conditional mean of j given i set to the conditioning value.
+
+        Args:
+            X: evaluation point, shape (b, q, d)
+            train_X: Training X, shape (b, n, d)
+            train_Y: Training Y, shape (b, n, 1)
+            marginals: A posterior object with marginal posteriors for f(X), but no
+                correlation structure yet added. posterior.probabilities has
+                shape (b?, q, num_buckets).
+
+        Returns: conditional means (b, q, q), and values used for conditioning (b, q).
+        """
+        b, q, d = X.shape
+        n = train_X.shape[-2]
+        post_shape = marginals.probabilities.shape[:-1]
+        # Find the 90th percentile of each eval point.
+        cond_val = marginals.icdf(
+            torch.full(post_shape, 0.9, device=X.device, dtype=X.dtype).unsqueeze(0)
+        )  # (1, b?, q, 1)
+        cond_val = cond_val.view(b, q)  # (b, q)
+        # Construct conditional training data.
+        # train_X will have shape (b, q, n+1, d), to have a conditional observation
+        # for each point. train_Y will have shape (b, q, n+1, 1).
+        train_X = train_X.unsqueeze(1).expand(b, q, n, d)
+        cond_X = X.unsqueeze(-2)  # (b, q, 1, d)
+        train_X = torch.cat((train_X, cond_X), dim=-2)  # (b, q, n+1, d)
+        train_Y = train_Y.unsqueeze(1).expand(b, q, n, 1)
+        cond_Y = cond_val.unsqueeze(-1).unsqueeze(-1)  # (b, q, 1, 1)
+        train_Y = torch.cat((train_Y, cond_Y), dim=-2)  # (b, q, n+1, 1)
+        # Construct eval points
+        eval_X = X.unsqueeze(1).expand(b, q, q, d)
+        # Squeeze everything into necessary 2 batch dims, and do PFN forward pass
+        cond_probabilities = self.pfn_predict(
+            X=eval_X.reshape(b * q, q, d),
+            train_X=train_X.reshape(b * q, n + 1, d),
+            train_Y=train_Y.reshape(b * q, n + 1, 1),
+        )  # (b * q, q, num_buckets)
+        # Object for conditional posteriors
+        cond_posterior = BoundedRiemannPosterior(
+            borders=self.borders,
+            probabilities=cond_probabilities,
+        )
+        # Get conditional means
+        cond_mean = cond_posterior.mean.squeeze(-1)  # (b * q, q)
+        cond_mean = cond_mean.unsqueeze(0).view(b, q, q)
+        return cond_mean, cond_val
+
+    def _estimate_covariances(
+        self,
+        cond_mean: Tensor,
+        cond_val: Tensor,
+        mean: Tensor,
+        var: Tensor,
+    ) -> Tensor:
+        """
+        Estimate covariances from conditional distributions.
+
+        Part one: Compute noise variance implied by conditional distributions
+        E[f_j | y_j=y] = E[f_j] + var[f_j]/(var[f_j] + noise_var) * (y - E[f_j])
+        Let Z_jj = (E[f_j | y_j=y] - E[f_j]) / (y - E[f_j]).
+        Note that Z is in (0, 1].
+        Then, noise_var_j = var[f_j](1/Z_jj - 1).
+
+        Part two: Compute covariances for all pairs
+        E[f_j|y_i=y] = E[f_j]+cov[f_j, f_i]/(var[f_i] + noise_var_i) * (y - E[f_i])
+        Let Z_ij = (E[f_j | y_i=y] - E[f_j]) / (y - E[f_i]).
+        Then, cov[f_j, f_i] = Z_ij * (var[f_i] + noise_var_i).
+
+        Args:
+            cond_mean: (b, q, q) means of dim -1 conditioned on dim -2
+            cond_val: (b, q) conditioned y value.
+            var: (b, q) marginal variances
+            mean: (b, q) marginal means
+
+        Returns: Covariance matrix
+        """
+        Z = (cond_mean - mean.unsqueeze(-2).expand(cond_mean.shape)) / (
+            cond_val - mean
+        ).unsqueeze(-1)  # (b, q, q)
+        # Z[i, j] is for j cond. on i
+        noise_var = torch.clamp(
+            var * (1 / torch.diagonal(Z, dim1=-2, dim2=-1) - 1), min=1e-8
+        )  # (b, q)
+        cov = Z * (var + noise_var).unsqueeze(-1)  # (b, q, q)
+        # Symmetrize
+        cov = 0.5 * (cov + cov.transpose(-1, -2))
+        cov = self._map_psd(cov)
+        return cov
+
+    def _map_psd(self, A):
+        """
+        Map A (assumed symmetric) to the nearest PSD matrix.
+        """
+        if torch.linalg.eigvals(A).real.min() < 0:
+            L, Q = torch.linalg.eigh(A)
+            L = torch.clamp(L, min=1e-6)
+            A = Q @ torch.diag_embed(L) @ Q.transpose(-1, -2)
+        return A
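
The two-part estimator in _estimate_covariances follows exact Gaussian conditioning: with observation noise on y_i, E[f_j | y_i = y] = E[f_j] + cov(f_j, f_i) / (var(f_i) + noise_var_i) * (y - E[f_i]), so the ratios Z recover first the noise variance (from the diagonal) and then the covariances. Below is a minimal, self-contained numerical check of those two identities on a known bivariate Gaussian; all names and values are illustrative, not part of the commit.

import torch

# True moments of (f_i, f_j) and the observation-noise variance on y_i.
mean = torch.tensor([0.5, -1.0])
cov = torch.tensor([[1.0, 0.6], [0.6, 2.0]])
noise_var = 0.3

# Exact Gaussian conditional mean E[f | y_i = y] for some conditioning value y.
y = mean[0] + 1.2
cond_mean = mean + cov[0] / (cov[0, 0] + noise_var) * (y - mean[0])

# Recover the estimates from cond_mean alone, mirroring _estimate_covariances.
Z = (cond_mean - mean) / (y - mean[0])
noise_hat = cov[0, 0] * (1 / Z[0] - 1)    # Part one: noise var from Z_ii
cov_hat = Z[1] * (cov[0, 0] + noise_hat)  # Part two: cov[f_j, f_i] from Z_ij

print(noise_hat.item(), cov_hat.item())  # recovers 0.3 and 0.6 exactly

And _map_psd is the standard eigenvalue-clamping projection; here is the same operation on a small symmetric matrix with a negative eigenvalue, again as an illustrative sketch rather than the commit's test code:

A = torch.tensor([[1.0, 0.99, 0.0], [0.99, 1.0, 0.99], [0.0, 0.99, 1.0]])
print(torch.linalg.eigvalsh(A).min())  # < 0, so A is not PSD
evals, evecs = torch.linalg.eigh(A)    # symmetric eigendecomposition
evals = torch.clamp(evals, min=1e-6)   # clamp negative eigenvalues
A_psd = evecs @ torch.diag_embed(evals) @ evecs.transpose(-1, -2)
print(torch.linalg.eigvalsh(A_psd).min())  # >= 1e-6 up to roundoff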

botorch_community/posteriors/riemann.py

Lines changed: 50 additions & 0 deletions
@@ -16,6 +16,7 @@
 from botorch.posteriors.posterior import Posterior
 from botorch.sampling.get_sampler import _get_sampler_mvn, GetSampler
 from botorch.sampling.normal import NormalMCSampler
+from gpytorch.distributions.multivariate_normal import MultivariateNormal
 from torch import Tensor


@@ -209,6 +210,55 @@ def icdf(
         return result


+class MultivariateRiemannPosterior(BoundedRiemannPosterior):
+    def __init__(self, borders, probabilities, correlation_matrix) -> None:
+        """
+        A multi-variate bounded Riemann posterior using a Gaussian copula.
+
+        Uses BoundedRiemannPosterior for marginal distributions, and then MVN
+        correlation structure via the Gaussian copula.
+
+        Args:
+            borders: A tensor of shape `(num_buckets + 1,)` defining the boundaries of
+                the buckets. Must be monotonically increasing.
+            probabilities: A tensor of shape `(b?, q, num_buckets)` defining the
+                probability mass in each bucket. Must sum to 1 in the last dim.
+            correlation_matrix: The Gaussian correlation matrix, (b?, q, q).
+        """
+        super().__init__(borders=borders, probabilities=probabilities)
+        self.correlation_matrix = correlation_matrix
+
+    def rsample_from_base_samples(
+        self, sample_shape: torch.Size, base_samples: Tensor
+    ) -> Tensor:
+        """
+        Sample from the posterior using base samples.
+
+        base_samples are N(0, I) samples, as this posterior is registered
+        with the IIDNormalSampler below. This is also necessary for the use
+        of the Gaussian copula.
+
+        Args:
+            sample_shape: Shape of samples.
+            base_samples: (nsamp, b?, q) standard normal samples.
+
+        Returns: Samples from copula, shape (nsamp, b?, q, 1)
+        """
+        # Construct MVN
+        mvn = MultivariateNormal(
+            mean=torch.zeros_like(self.mean.squeeze(-1)),
+            covariance_matrix=self.correlation_matrix,
+        )
+        # Draw samples
+        samples = mvn.rsample(
+            sample_shape=sample_shape, base_samples=base_samples
+        ).squeeze(-1)  # (nsamp, b?, q)
+        # Convert N(0, 1) marginals to Uniform samples.
+        U = torch.distributions.Normal(0, 1).cdf(samples)  # (nsamp, b?, q)
+        Z = self.icdf(U)  # (nsamp, b?, q, 1)
+        return Z
+
+
 @GetSampler.register(BoundedRiemannPosterior)
 def _get_sampler_riemann(
     posterior: BoundedRiemannPosterior,
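
The copula mechanics in rsample_from_base_samples can be reproduced in isolation: correlate standard-normal base samples with the correlation matrix, push them through the standard normal CDF to get dependent uniforms, then map through any marginal icdf. A self-contained sketch with an Exponential marginal standing in for the Riemann icdf (illustrative names and values only):

import torch

torch.manual_seed(0)
R = torch.tensor([[1.0, 0.8], [0.8, 1.0]])  # correlation matrix, as in the posterior
L = torch.linalg.cholesky(R)

base = torch.randn(10000, 2)  # N(0, I) base samples, cf. IIDNormalSampler
z = base @ L.T  # correlated N(0, R) samples
u = torch.distributions.Normal(0.0, 1.0).cdf(z)   # dependent Uniform(0, 1) marginals
x = torch.distributions.Exponential(1.0).icdf(u)  # any marginal via its icdf

print(x.mean(dim=0))        # ~[1.0, 1.0]: each marginal is exactly Exponential(1)
print(torch.corrcoef(x.T))  # positive dependence inherited from R

The correlation matrix only shapes the joint draw; the marginals are untouched, which is why the posterior keeps its Riemann marginals while adding joint structure over the q-batch.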
