Skip to content

Commit f916855

Browse files
committed
Improve numerical precision in MVNScore.log_prob
1 parent e329f4b commit f916855

File tree

1 file changed

+8
-6
lines changed

1 file changed

+8
-6
lines changed

bayesflow/scores/multivariate_normal_score.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -82,13 +82,15 @@ def log_prob(self, x: Tensor, mean: Tensor, cov_chol: Tensor) -> Tensor:
8282
"""
8383
diff = x - mean
8484

85-
# Calculate covariance from Cholesky factors
86-
covariance = keras.ops.matmul(
87-
cov_chol,
88-
keras.ops.swapaxes(cov_chol, -2, -1),
85+
# Calculate precision from Cholesky factors of covariance matrix
86+
cov_chol_inv = keras.ops.inv(cov_chol)
87+
precision = keras.ops.matmul(
88+
keras.ops.swapaxes(cov_chol_inv, -2, -1),
89+
cov_chol_inv,
8990
)
90-
precision = keras.ops.inv(covariance)
91-
log_det_covariance = keras.ops.slogdet(covariance)[1] # Only take the log of the determinant part
91+
92+
# Compute log determinant, exploiting Cholesky factors
93+
log_det_covariance = keras.ops.log(keras.ops.prod(keras.ops.diagonal(cov_chol, axis1=1, axis2=2), axis=1)) * 2
9294

9395
# Compute the quadratic term in the exponential of the multivariate Gaussian
9496
quadratic_term = keras.ops.einsum("...i,...ij,...j->...", diff, precision, diff)

0 commit comments

Comments
 (0)