Commit ef97a6c
Replace PositiveDefinite link with CholeskyFactor
This makes the MVN score sampling test stable for the jax backend, where the keras.ops.cholesky operation is numerically unstable. To resolve the issue, the score's sample method no longer calls keras.ops.cholesky; instead, the estimation head now returns the Cholesky factor directly rather than the covariance matrix (as it used to).
1 parent 82e28a7 commit ef97a6c
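
The identity the fix relies on: if z ~ N(0, I) and L is a Cholesky factor of the covariance Sigma (i.e. Sigma = L L^T), then mean + L z ~ N(mean, Sigma). Sampling therefore only ever needs the factor L, so predicting L directly removes any need to factorize Sigma at sampling time. A minimal numpy sketch of this identity (illustrative shapes only, not the bayesflow API):

import numpy as np

rng = np.random.default_rng(0)
dim, n = 3, 200_000
mean = np.array([1.0, -2.0, 0.5])
A = rng.normal(size=(dim, dim))
cov = A @ A.T + dim * np.eye(dim)   # a well-conditioned SPD covariance
L = np.linalg.cholesky(cov)         # the factor a head would predict directly

z = rng.standard_normal((n, dim))   # independent standard normal samples
x = mean + z @ L.T                  # equivalent to mean + L z for each sample

# the empirical covariance of the transformed samples recovers cov
assert np.allclose(np.cov(x, rowvar=False), cov, atol=0.1)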

5 files changed: 48 additions, 49 deletions

bayesflow/links/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -2,7 +2,7 @@
 
 from .ordered import Ordered
 from .ordered_quantiles import OrderedQuantiles
-from .positive_definite import PositiveDefinite
+from .cholesky_factor import CholeskyFactor
 
 from ..utils._docs import _add_imports_to_all
 
bayesflow/links/positive_definite.py renamed to bayesflow/links/cholesky_factor.py

Lines changed: 3 additions & 8 deletions
@@ -6,8 +6,8 @@
 
 
 @serializable("bayesflow.links")
-class PositiveDefinite(keras.Layer):
-    """Activation function to link from flat elements of a lower triangular matrix to a positive definite matrix."""
+class CholeskyFactor(keras.Layer):
+    """Activation function to link from a flat tensor to a lower triangular matrix with positive diagonal."""
 
     def __init__(self, **kwargs):
         super().__init__(**layer_kwargs(kwargs))
@@ -17,12 +17,7 @@ def call(self, inputs: Tensor) -> Tensor:
         L = fill_triangular_matrix(inputs)
         L = positive_diag(L)
 
-        # calculate positive definite matrix from cholesky factors:
-        psd = keras.ops.matmul(
-            L,
-            keras.ops.swapaxes(L, -2, -1),  # L transposed
-        )
-        return psd
+        return L
 
     def compute_output_shape(self, input_shape):
         m = input_shape[-1]
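
For orientation, here is a self-contained numpy sketch of what the link now computes. The packing order of bayesflow's fill_triangular_matrix and the bijection applied by positive_diag are assumptions here (row-wise filling and exp, respectively); the actual utilities may differ in detail.

import numpy as np

def cholesky_factor_link(flat: np.ndarray, dim: int) -> np.ndarray:
    """Map dim * (dim + 1) // 2 unconstrained values to a lower triangular
    matrix with strictly positive diagonal, i.e. a valid Cholesky factor."""
    L = np.zeros((dim, dim))
    L[np.tril_indices(dim)] = flat   # fill the lower triangle (assumed row-wise)
    d = np.diag_indices(dim)
    L[d] = np.exp(L[d])              # enforce a positive diagonal (assumed exp)
    return L

flat = np.random.default_rng(42).normal(size=3 * 4 // 2)  # dim = 3 needs 6 values
L = cholesky_factor_link(flat, dim=3)
cov = L @ L.T   # by construction symmetric positive definite

Because any lower triangular matrix with strictly positive diagonal is a valid Cholesky factor, L @ L.T is always symmetric positive definite, which is exactly what the old PositiveDefinite link used to return.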

bayesflow/scores/multivariate_normal_score.py

Lines changed: 27 additions & 22 deletions
@@ -3,7 +3,7 @@
 import keras
 
 from bayesflow.types import Shape, Tensor
-from bayesflow.links import PositiveDefinite
+from bayesflow.links import CholeskyFactor
 from bayesflow.utils.serialization import serializable
 
 from .parametric_distribution_score import ParametricDistributionScore
@@ -13,26 +13,27 @@
 class MultivariateNormalScore(ParametricDistributionScore):
     r""":math:`S(\hat p_{\mu, \Sigma}, \theta; k) = -\log( \mathcal N (\theta; \mu, \Sigma))`
 
-    Scores a predicted mean and covariance matrix with the log-score of the probability of the materialized value.
+    Scores a predicted mean and (Cholesky factor of the) covariance matrix with the log-score of the probability
+    of the materialized value.
     """
 
-    NOT_TRANSFORMING_LIKE_VECTOR_WARNING = ("covariance",)
+    NOT_TRANSFORMING_LIKE_VECTOR_WARNING = ("cov_chol",)
     """
-    Marks head for covariance matrix as an exception for adapter transformations.
+    Marks head for covariance matrix Cholesky factor as an exception for adapter transformations.
 
     This variable contains names of prediction heads that should lead to a warning when the adapter is applied
     in inverse direction to them.
 
     For more information see :py:class:`ScoringRule`.
     """
 
-    TRANSFORMATION_TYPE: dict[str, str] = {"covariance": "both_sides_scale"}
+    TRANSFORMATION_TYPE: dict[str, str] = {"cov_chol": "left_side_scale"}
     """
-    Marks covariance head to handle de-standardization as for covariant rank-(0,2) tensors.
+    Marks covariance Cholesky factor head to handle de-standardization as for covariant rank-(0,2) tensors.
 
     The appropriate inverse of the standardization operation is
 
-    x_ij = x_ij' * sigma_i * sigma_j.
+    x_ij = sigma_i * x_ij'.
 
     For the mean head the default ("location_scale") is not overridden.
     """
@@ -41,7 +42,7 @@ def __init__(self, dim: int = None, links: dict = None, **kwargs):
         super().__init__(links=links, **kwargs)
 
         self.dim = dim
-        self.links = links or {"covariance": PositiveDefinite()}
+        self.links = links or {"cov_chol": CholeskyFactor()}
 
         self.config = {"dim": dim}
 
@@ -51,14 +52,14 @@ def get_config(self):
 
     def get_head_shapes_from_target_shape(self, target_shape: Shape) -> dict[str, Shape]:
         self.dim = target_shape[-1]
-        return dict(mean=(self.dim,), covariance=(self.dim, self.dim))
+        return dict(mean=(self.dim,), cov_chol=(self.dim, self.dim))
 
-    def log_prob(self, x: Tensor, mean: Tensor, covariance: Tensor) -> Tensor:
+    def log_prob(self, x: Tensor, mean: Tensor, cov_chol: Tensor) -> Tensor:
         """
         Compute the log probability density of a multivariate Gaussian distribution.
 
         This function calculates the log probability density for each sample in `x` under a
-        multivariate Gaussian distribution with the given `mean` and `covariance`.
+        multivariate Gaussian distribution with the given `mean` and `cov_chol`.
 
         The computation includes the determinant of the covariance matrix, its inverse, and the quadratic
         form in the exponential term of the Gaussian density function.
@@ -80,6 +81,12 @@ def log_prob(self, x: Tensor, mean: Tensor, covariance: Tensor) -> Tensor:
             given Gaussian distribution.
         """
         diff = x - mean
+
+        # Calculate covariance from Cholesky factors
+        covariance = keras.ops.matmul(
+            cov_chol,
+            keras.ops.swapaxes(cov_chol, -2, -1),
+        )
         precision = keras.ops.inv(covariance)
         log_det_covariance = keras.ops.slogdet(covariance)[1]  # Only take the log of the determinant part
 
@@ -91,14 +98,12 @@ def log_prob(self, x: Tensor, mean: Tensor, covariance: Tensor) -> Tensor:
 
         return log_prob
 
-    def sample(self, batch_shape: Shape, mean: Tensor, covariance: Tensor) -> Tensor:
+    def sample(self, batch_shape: Shape, mean: Tensor, cov_chol: Tensor) -> Tensor:
         """
         Generate samples from a multivariate Gaussian distribution.
 
-        This function samples from a multivariate Gaussian distribution with the given `mean`
-        and `covariance` using the Cholesky decomposition method. Independent standard normal
-        samples are transformed using the Cholesky factor of the covariance matrix to generate
-        correlated samples.
+        Independent standard normal samples are transformed using the Cholesky factor of the covariance matrix
+        to generate correlated samples.
 
         Parameters
         ----------
@@ -107,8 +112,8 @@ def sample(self, batch_shape: Shape, mean: Tensor, covariance: Tensor) -> Tensor
         mean : Tensor
             A tensor representing the mean of the multivariate Gaussian distribution.
             Must have shape (batch_size, D), where D is the dimensionality of the distribution.
-        covariance : Tensor
-            A tensor representing the covariance matrix of the multivariate Gaussian distribution.
+        cov_chol : Tensor
+            A tensor representing a Cholesky factor of the covariance matrix of the multivariate Gaussian distribution.
             Must have shape (batch_size, D, D), where D is the dimensionality.
 
         Returns
@@ -123,16 +128,16 @@ def sample(self, batch_shape: Shape, mean: Tensor, covariance: Tensor) -> Tensor
         if keras.ops.shape(mean) != (batch_size, dim):
             raise ValueError(f"mean must have shape (batch_size, {dim}), but got {keras.ops.shape(mean)}")
 
-        if keras.ops.shape(covariance) != (batch_size, dim, dim):
+        if keras.ops.shape(cov_chol) != (batch_size, dim, dim):
             raise ValueError(
-                f"covariance must have shape (batch_size, {dim}, {dim}), but got {keras.ops.shape(covariance)}"
+                f"covariance Cholesky factor must have shape (batch_size, {dim}, {dim}),"
+                f"but got {keras.ops.shape(cov_chol)}"
             )
 
         # Use Cholesky decomposition to generate samples
-        cholesky_factor = keras.ops.cholesky(covariance)
         normal_samples = keras.random.normal((*batch_shape, dim))
 
-        scaled_normal = keras.ops.einsum("ijk,ilk->ilj", cholesky_factor, normal_samples)
+        scaled_normal = keras.ops.einsum("ijk,ilk->ilj", cov_chol, normal_samples)
         samples = mean[:, None, :] + scaled_normal
 
         return samples
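
The switch from "both_sides_scale" to "left_side_scale" follows from how the factor transforms under de-standardization: if Sigma' = S Sigma S for a diagonal scale matrix S = diag(sigma), then S L is a Cholesky factor of Sigma', so the factor only needs to be multiplied from the left. A hedged numpy check of this equivalence (not bayesflow code):

import numpy as np

rng = np.random.default_rng(1)
dim = 4
A = rng.normal(size=(dim, dim))
L = np.linalg.cholesky(A @ A.T + dim * np.eye(dim))  # some valid Cholesky factor
S = np.diag(rng.uniform(0.5, 2.0, size=dim))         # per-dimension scales sigma_i

cov_from_scaled_factor = (S @ L) @ (S @ L).T  # left_side_scale applied to the factor
cov_scaled_directly = S @ (L @ L.T) @ S       # both_sides_scale applied to the covariance
assert np.allclose(cov_from_scaled_factor, cov_scaled_directly)

Note that S @ L stays lower triangular with a positive diagonal, so the de-standardized head output remains a valid Cholesky factor.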

tests/test_links/conftest.py

Lines changed: 4 additions & 4 deletions
@@ -33,18 +33,18 @@ def ordered_quantiles():
 
 
 @pytest.fixture()
-def positive_definite():
-    from bayesflow.links import PositiveDefinite
+def cholesky_factor():
+    from bayesflow.links import CholeskyFactor
 
-    return PositiveDefinite()
+    return CholeskyFactor()
 
 
 @pytest.fixture()
 def linear():
     return keras.layers.Activation("linear")
 
 
-@pytest.fixture(params=["ordered", "ordered_quantiles", "positive_definite", "linear"], scope="function")
+@pytest.fixture(params=["ordered", "ordered_quantiles", "cholesky_factor", "linear"], scope="function")
 def link(request):
     return request.getfixturevalue(request.param)
 

tests/test_links/test_links.py

Lines changed: 13 additions & 14 deletions
@@ -52,21 +52,20 @@ def test_quantile_ordering(quantiles, unordered):
     check_ordering(output, axis)
 
 
-def test_positive_definite(positive_definite, batch_size, num_variables):
-    input_shape = positive_definite.compute_input_shape((batch_size, num_variables, num_variables))
+def test_cholesky_factor(cholesky_factor, batch_size, num_variables):
+    input_shape = cholesky_factor.compute_input_shape((batch_size, num_variables, num_variables))
 
-    # Too strongly negative values lead to numerical instabilities -> reduce scale
-    random_preactivation = keras.random.normal(input_shape) * 0.1
-    output = positive_definite(random_preactivation)
-    output = keras.ops.convert_to_numpy(output)
-
-    # Check if output is invertible
-    np.linalg.inv(output)
+    random_preactivation = keras.random.normal(input_shape)
 
-    # Calculated eigenvalues to test for positive definiteness
-    eigenvalues = np.linalg.eig(output).eigenvalues
+    output = cholesky_factor(random_preactivation)
+    output = keras.ops.convert_to_numpy(output)
 
-    assert np.all(eigenvalues.real > 0) and np.all(np.isclose(eigenvalues.imag, 0)), (
-        f"output is not positive definite: min(real)={np.min(eigenvalues.real)}, "
-        f"max(abs(imag))={np.max(np.abs(eigenvalues.imag))}"
+    np.testing.assert_allclose(
+        np.triu(output, k=1),
+        np.zeros((batch_size, num_variables, num_variables)),
+        atol=1e-4,
+        err_msg=f"All elements above diagonal must be zero for lower triangular matrix: {output}",
     )
+
+    diag = np.diagonal(output, axis1=1, axis2=2)
+    assert np.all(diag > 0), f"diagonal is not strictly positive: {diag}"
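
The rewritten test checks lower triangularity and a strictly positive diagonal instead of eigenvalues. This is sufficient because for any lower triangular L with positive diagonal, L L^T is symmetric positive definite, which is exactly the property the old eigenvalue-based test asserted. A small numpy sketch of that implication (illustrative only):

import numpy as np

rng = np.random.default_rng(2)
L = np.tril(rng.normal(size=(4, 4)))  # lower triangular
d = np.diag_indices(4)
L[d] = np.abs(L[d]) + 0.1             # strictly positive diagonal

# L @ L.T is symmetric with all eigenvalues positive, i.e. positive definite
eigenvalues = np.linalg.eigvalsh(L @ L.T)
assert np.all(eigenvalues > 0)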
