@@ -3,7 +3,7 @@
 import keras

 from bayesflow.types import Shape, Tensor
-from bayesflow.links import PositiveDefinite
+from bayesflow.links import CholeskyFactor
 from bayesflow.utils.serialization import serializable

 from .parametric_distribution_score import ParametricDistributionScore
@@ -13,26 +13,27 @@
 class MultivariateNormalScore(ParametricDistributionScore):
     r""":math:`S(\hat p_{\mu, \Sigma}, \theta; k) = -\log(\mathcal N(\theta; \mu, \Sigma))`

-    Scores a predicted mean and covariance matrix with the log-score of the probability of the materialized value.
+    Scores a predicted mean and (Cholesky factor of the) covariance matrix with the log-score of the probability
+    of the materialized value.
     """

-    NOT_TRANSFORMING_LIKE_VECTOR_WARNING = ("covariance",)
+    NOT_TRANSFORMING_LIKE_VECTOR_WARNING = ("cov_chol",)
     """
-    Marks head for covariance matrix as an exception for adapter transformations.
+    Marks head for covariance matrix Cholesky factor as an exception for adapter transformations.

     This variable contains names of prediction heads that should lead to a warning when the adapter is applied
     in inverse direction to them.

     For more information see :py:class:`ScoringRule`.
     """

-    TRANSFORMATION_TYPE: dict[str, str] = {"covariance": "both_sides_scale"}
+    TRANSFORMATION_TYPE: dict[str, str] = {"cov_chol": "left_side_scale"}
     """
-    Marks covariance head to handle de-standardization as for covariant rank-(0,2) tensors.
+    Marks the covariance Cholesky factor head to handle de-standardization by one-sided (row) scaling.

     The appropriate inverse of the standardization operation is

-    x_ij = x_ij' * sigma_i * sigma_j.
+    x_ij = sigma_i * x_ij'.

     For the mean head the default ("location_scale") is not overridden.
     """
@@ -41,7 +42,7 @@ def __init__(self, dim: int = None, links: dict = None, **kwargs):
         super().__init__(links=links, **kwargs)

         self.dim = dim
-        self.links = links or {"covariance": PositiveDefinite()}
+        self.links = links or {"cov_chol": CholeskyFactor()}

         self.config = {"dim": dim}

@@ -51,14 +52,14 @@ def get_config(self):

     def get_head_shapes_from_target_shape(self, target_shape: Shape) -> dict[str, Shape]:
         self.dim = target_shape[-1]
-        return dict(mean=(self.dim,), covariance=(self.dim, self.dim))
+        return dict(mean=(self.dim,), cov_chol=(self.dim, self.dim))

-    def log_prob(self, x: Tensor, mean: Tensor, covariance: Tensor) -> Tensor:
+    def log_prob(self, x: Tensor, mean: Tensor, cov_chol: Tensor) -> Tensor:
         """
         Compute the log probability density of a multivariate Gaussian distribution.

         This function calculates the log probability density for each sample in `x` under a
-        multivariate Gaussian distribution with the given `mean` and `covariance`.
+        multivariate Gaussian distribution with the given `mean` and `cov_chol`.

         The computation includes the determinant of the covariance matrix, its inverse, and the quadratic
         form in the exponential term of the Gaussian density function.
@@ -80,6 +81,12 @@ def log_prob(self, x: Tensor, mean: Tensor, covariance: Tensor) -> Tensor:
             given Gaussian distribution.
         """
         diff = x - mean
+
+        # Calculate covariance from Cholesky factors
+        covariance = keras.ops.matmul(
+            cov_chol,
+            keras.ops.swapaxes(cov_chol, -2, -1),
+        )
         precision = keras.ops.inv(covariance)
         log_det_covariance = keras.ops.slogdet(covariance)[1]  # Only take the log of the determinant part

@@ -91,14 +98,12 @@ def log_prob(self, x: Tensor, mean: Tensor, covariance: Tensor) -> Tensor:

         return log_prob

-    def sample(self, batch_shape: Shape, mean: Tensor, covariance: Tensor) -> Tensor:
+    def sample(self, batch_shape: Shape, mean: Tensor, cov_chol: Tensor) -> Tensor:
         """
         Generate samples from a multivariate Gaussian distribution.

-        This function samples from a multivariate Gaussian distribution with the given `mean`
-        and `covariance` using the Cholesky decomposition method. Independent standard normal
-        samples are transformed using the Cholesky factor of the covariance matrix to generate
-        correlated samples.
+        Independent standard normal samples are transformed using the Cholesky factor of the covariance matrix
+        to generate correlated samples.

         Parameters
         ----------
@@ -107,8 +112,8 @@ def sample(self, batch_shape: Shape, mean: Tensor, covariance: Tensor) -> Tensor
         mean : Tensor
             A tensor representing the mean of the multivariate Gaussian distribution.
             Must have shape (batch_size, D), where D is the dimensionality of the distribution.
-        covariance : Tensor
-            A tensor representing the covariance matrix of the multivariate Gaussian distribution.
+        cov_chol : Tensor
+            A tensor representing a Cholesky factor of the covariance matrix of the multivariate Gaussian distribution.
             Must have shape (batch_size, D, D), where D is the dimensionality.

         Returns
@@ -123,16 +128,16 @@ def sample(self, batch_shape: Shape, mean: Tensor, covariance: Tensor) -> Tensor
         if keras.ops.shape(mean) != (batch_size, dim):
             raise ValueError(f"mean must have shape (batch_size, {dim}), but got {keras.ops.shape(mean)}")

-        if keras.ops.shape(covariance) != (batch_size, dim, dim):
+        if keras.ops.shape(cov_chol) != (batch_size, dim, dim):
             raise ValueError(
-                f"covariance must have shape (batch_size, {dim}, {dim}), but got {keras.ops.shape(covariance)}"
+                f"covariance Cholesky factor must have shape (batch_size, {dim}, {dim}), "
+                f"but got {keras.ops.shape(cov_chol)}"
             )

         # Use Cholesky decomposition to generate samples
-        cholesky_factor = keras.ops.cholesky(covariance)
         normal_samples = keras.random.normal((*batch_shape, dim))

-        scaled_normal = keras.ops.einsum("ijk,ilk->ilj", cholesky_factor, normal_samples)
+        scaled_normal = keras.ops.einsum("ijk,ilk->ilj", cov_chol, normal_samples)
         samples = mean[:, None, :] + scaled_normal

         return samples
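
Why "left_side_scale" suffices for the Cholesky head: if targets are standardized per dimension with scales sigma_i, a covariance matrix de-standardizes on both sides, Sigma_ij = sigma_i * Sigma'_ij * sigma_j. A Cholesky factor needs only one side: with D = diag(sigma), we have D L' (L')^T D = (D L')(D L')^T, and D L' is again lower-triangular with positive diagonal. Row scaling x_ij = sigma_i * x_ij' therefore yields a valid factor of the de-standardized covariance, which is why the change from "both_sides_scale" to "left_side_scale" accompanies the switch from a covariance head to a Cholesky factor head.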
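For reference, log_prob evaluates the standard multivariate normal log-density with Sigma = L L^T built from the predicted Cholesky factor L:

    log N(x; mu, Sigma) = -1/2 * [ D * log(2*pi) + log det(Sigma) + (x - mu)^T Sigma^{-1} (x - mu) ]

which is exactly the determinant, inverse, and quadratic-form computation described in the docstring.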
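A minimal, self-contained sketch of the Cholesky sampling trick used in sample(); the shapes and the scaled-identity factor below are illustrative assumptions, not part of the class:

    import keras

    batch_size, num_samples, dim = 2, 5, 3

    mean = keras.random.normal((batch_size, dim))
    # An example lower-triangular Cholesky factor: a scaled identity per batch entry.
    cov_chol = 0.5 * keras.ops.tile(keras.ops.expand_dims(keras.ops.eye(dim), 0), (batch_size, 1, 1))

    # Correlate independent standard normals: x = mean + L @ z, so Cov(x) = L L^T.
    z = keras.random.normal((batch_size, num_samples, dim))
    samples = mean[:, None, :] + keras.ops.einsum("ijk,ilk->ilj", cov_chol, z)
    print(keras.ops.shape(samples))  # (2, 5, 3)

Passing the factor directly to sample() avoids the keras.ops.cholesky call the removed line performed on a full covariance matrix.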