Skip to content

Stabilize MultivariateNormalScore by constraining initialization in PositiveDefinite link #469

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
May 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 17 additions & 11 deletions bayesflow/links/positive_definite.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import keras

from bayesflow.types import Tensor
from bayesflow.utils import layer_kwargs, fill_triangular_matrix
from bayesflow.utils import layer_kwargs, fill_triangular_matrix, positive_diag
from bayesflow.utils.serialization import serializable


Expand All @@ -11,16 +11,21 @@ class PositiveDefinite(keras.Layer):

def __init__(self, **kwargs):
super().__init__(**layer_kwargs(kwargs))
self.built = True

self.layer_norm = keras.layers.LayerNormalization()

def call(self, inputs: Tensor) -> Tensor:
# Build cholesky factor from inputs
L = fill_triangular_matrix(inputs, positive_diag=True)
# normalize the activation at initialization time mean = 0.0, std = 0.1
inputs = self.layer_norm(inputs) / 10

# form a cholesky factor
L = fill_triangular_matrix(inputs)
L = positive_diag(L)

# calculate positive definite matrix from cholesky factors
# calculate positive definite matrix from cholesky factors:
psd = keras.ops.matmul(
L,
keras.ops.moveaxis(L, -2, -1), # L transposed
keras.ops.swapaxes(L, -2, -1), # L transposed
)
return psd

Expand All @@ -31,13 +36,14 @@ def compute_output_shape(self, input_shape):

def compute_input_shape(self, output_shape):
"""
Returns the shape of parameterization of a cholesky factor triangular matrix.
Returns the shape of parameterization of a Cholesky factor triangular matrix.

There are m nonzero elements of a lower triangular nxn matrix with m = n * (n + 1) / 2.
There are :math:`m` nonzero elements of a lower triangular :math:`n \\times n` matrix with
:math:`m = n (n + 1) / 2`, so for output shape (..., n, n) the returned shape is (..., m).

Example
-------
>>> PositiveDefinite().compute_output_shape((None, 3, 3))
Examples
--------
>>> PositiveDefinite().compute_input_shape((None, 3, 3))
6
"""
n = output_shape[-1]
Expand Down
3 changes: 2 additions & 1 deletion bayesflow/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,14 +88,15 @@
expand_right_as,
expand_right_to,
expand_tile,
fill_triangular_matrix,
pad,
positive_diag,
searchsorted,
size_of,
stack_valid,
tile_axis,
tree_concatenate,
tree_stack,
fill_triangular_matrix,
weighted_mean,
)

Expand Down
105 changes: 63 additions & 42 deletions bayesflow/utils/tensor_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,8 +310,6 @@
Batch of flattened nonzero matrix elements for triangular matrix.
upper : bool
Return upper triangular matrix if True, else lower triangular matrix. Default is False.
positive_diag : bool
Whether to apply a softplus operation to diagonal elements. Default is False.

Returns
-------
Expand All @@ -327,47 +325,70 @@
batch_shape = x.shape[:-1]
m = x.shape[-1]

if m == 1:
y = keras.ops.reshape(x, (-1, 1, 1))
if positive_diag:
y = keras.activations.softplus(y)
return y

# Calculate matrix shape
n = (0.25 + 2 * m) ** 0.5 - 0.5
if not np.isclose(np.floor(n), n):
raise ValueError(f"Input right-most shape ({m}) does not correspond to a triangular matrix.")
else:
n = int(n)

# Trick: Create triangular matrix by concatenating with a flipped version of its tail, then reshape.
x_tail = keras.ops.take(x, indices=list(range((m - (n**2 - m)), x.shape[-1])), axis=-1)
if not upper:
y = keras.ops.concatenate([x_tail, keras.ops.flip(x, axis=-1)], axis=len(batch_shape))
y = keras.ops.reshape(y, (-1, n, n))
y = keras.ops.tril(y)

if positive_diag:
y_offdiag = keras.ops.tril(y, k=-1)
# carve out diagonal, by setting upper and lower offdiagonals to zero
y_diag = keras.ops.tril(
keras.ops.triu(keras.activations.softplus(y)), # apply softplus to enforce positivity
if m > 1: # Matrix is larger than than 1x1
# Calculate matrix shape
n = (0.25 + 2 * m) ** 0.5 - 0.5
if not np.isclose(np.floor(n), n):
raise ValueError(f"Input right-most shape ({m}) does not correspond to a triangular matrix.")

Check warning on line 332 in bayesflow/utils/tensor_utils.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/utils/tensor_utils.py#L332

Added line #L332 was not covered by tests
else:
n = int(n)

# Trick: Create triangular matrix by concatenating with a flipped version of itself, then reshape.
if not upper:
x_list = [x, keras.ops.flip(x[..., n:], axis=-1)]

y = keras.ops.concatenate(x_list, axis=len(batch_shape))
y = keras.ops.reshape(y, (-1, n, n))
y = keras.ops.tril(y)

else:
x_list = [x[..., n:], keras.ops.flip(x, axis=-1)]

Check warning on line 345 in bayesflow/utils/tensor_utils.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/utils/tensor_utils.py#L345

Added line #L345 was not covered by tests

y = keras.ops.concatenate(x_list, axis=len(batch_shape))
y = keras.ops.reshape(y, (-1, n, n))
y = keras.ops.triu(

Check warning on line 349 in bayesflow/utils/tensor_utils.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/utils/tensor_utils.py#L347-L349

Added lines #L347 - L349 were not covered by tests
y,
)
y = y_diag + y_offdiag

else:
y = keras.ops.concatenate([x, keras.ops.flip(x_tail, axis=-1)], axis=len(batch_shape))
y = keras.ops.reshape(y, (-1, n, n))
y = keras.ops.triu(
y,
)

if positive_diag:
y_offdiag = keras.ops.triu(y, k=1)
# carve out diagonal, by setting upper and lower offdiagonals to zero
y_diag = keras.ops.tril(
keras.ops.triu(keras.activations.softplus(y)), # apply softplus to enforce positivity
)
y = y_diag + y_offdiag
else: # Matrix is 1x1
y = keras.ops.reshape(x, (-1, 1, 1))

Check warning on line 354 in bayesflow/utils/tensor_utils.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/utils/tensor_utils.py#L354

Added line #L354 was not covered by tests

return y


def positive_diag(x: Tensor, method="default") -> Tensor:
"""
Ensures that matrix elements on diagonal are positive.

Parameters
----------
x : Tensor of shape (batch_size, n, n)
Batch of matrices.
method : str, optional
Method by which to ensure positivity of diagonal entries. Choose from
- "shifted_softplus": softplus(x + 0.5413)
- "exp": exp(x)
Both methods map a matrix filled with zeros to the unit matrix.
Default is "shifted_softplus".

Returns
-------
Tensor of shape (batch_size, n, n)
"""
# ensure positivity
match method:
case "default" | "shifted_softplus":
x_positive = keras.activations.softplus(x + 0.5413)
case "exp":
x_positive = keras.ops.exp(x)

Check warning on line 383 in bayesflow/utils/tensor_utils.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/utils/tensor_utils.py#L382-L383

Added lines #L382 - L383 were not covered by tests

# zero all offdiagonals
x_diag_positive = keras.ops.tril(keras.ops.triu(x_positive))

# zero diagonal entries
x_offdiag = keras.ops.triu(x, k=1) + keras.ops.tril(x, k=-1)

# sum to get full matrices with softplus applied only to diagonal entries
x = x_diag_positive + x_offdiag

return x