Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions bayesflow/networks/standardization/standardization.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def moving_std(self, index: int) -> Tensor:
"""
return keras.ops.where(
self.moving_m2[index] > 0,
keras.ops.sqrt(self.moving_m2[index] / self.count),
keras.ops.sqrt(self.moving_m2[index] / self.count[index]),
1.0,
)

Expand All @@ -53,7 +53,7 @@ def build(self, input_shape: Shape):
self.moving_m2 = [
self.add_weight(shape=(shape[-1],), initializer="zeros", trainable=False) for shape in flattened_shapes
]
self.count = self.add_weight(shape=(), initializer="zeros", trainable=False)
self.count = [self.add_weight(shape=(), initializer="zeros", trainable=False) for _ in flattened_shapes]

def call(
self,
Expand Down Expand Up @@ -150,7 +150,7 @@ def _update_moments(self, x: Tensor, index: int):
"""

reduce_axes = tuple(range(x.ndim - 1))
batch_count = keras.ops.cast(keras.ops.shape(x)[0], self.count.dtype)
batch_count = keras.ops.cast(keras.ops.prod(keras.ops.shape(x)[:-1]), self.count[index].dtype)

# Compute batch mean and M2 per feature
batch_mean = keras.ops.mean(x, axis=reduce_axes)
Expand All @@ -159,7 +159,7 @@ def _update_moments(self, x: Tensor, index: int):
# Read current totals
mean = self.moving_mean[index]
m2 = self.moving_m2[index]
count = self.count
count = self.count[index]

total_count = count + batch_count
delta = batch_mean - mean
Expand All @@ -169,4 +169,4 @@ def _update_moments(self, x: Tensor, index: int):

self.moving_mean[index].assign(new_mean)
self.moving_m2[index].assign(new_m2)
self.count.assign(total_count)
self.count[index].assign(total_count)
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ def test_save_and_load(tmp_path, approximator, train_dataset, validation_dataset
approximator.build(data_shapes)
for layer in approximator.standardize_layers.values():
assert layer.built
assert layer.count == 0
for count in layer.count:
assert count == 0.0
approximator.compute_metrics(**train_dataset[0])

keras.saving.save_model(approximator, tmp_path / "model.keras")
Expand Down
3 changes: 2 additions & 1 deletion tests/test_approximators/test_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,5 @@ def test_build(approximator, simulator, batch_size, adapter):
approximator.build(batch_shapes)
for layer in approximator.standardize_layers.values():
assert layer.built
assert layer.count == 0
for count in layer.count:
assert count == 0.0
33 changes: 33 additions & 0 deletions tests/test_networks/test_standardization.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,39 @@ def test_nested_consistency_forward_inverse():
np.testing.assert_allclose(random_input["b"], recovered["b"], atol=1e-4)


def test_nested_accuracy_forward():
    """Verify that moments accumulated over two training passes standardize the pooled data.

    The two entries of the nested input contribute a different number of
    samples per pass ("a" has an extra leading axis, "b" has differing batch
    sizes), so the test only passes when the layer's running statistics are
    accumulated separately for each entry. After both passes, standardizing
    the concatenation of both batches in inference mode must yield per-feature
    mean 0 and standard deviation 1 (up to ``atol=1e-4``).
    """
    from bayesflow.utils import tree_concatenate

    # Two training batches; draw order is a, b, a, b.
    first_batch = {
        "a": keras.random.normal((2, 3, 5)),
        "b": keras.random.normal((4, 3)),
    }
    second_batch = {
        "a": keras.random.normal((3, 3, 5)),
        "b": keras.random.normal((3, 3)),
    }

    # The pooled data is exactly what the layer has seen during training, so
    # its standardized version should have zero mean and unit std per feature.
    pooled = tree_concatenate([first_batch, second_batch], axis=0)

    layer = Standardization()
    for batch in (first_batch, second_batch):
        layer(batch, stage="training", forward=True)

    standardized = keras.tree.map_structure(
        keras.ops.convert_to_numpy,
        layer(pooled, stage="inference", forward=True),
    )

    def leading_axes(array):
        # Reduce over every axis except the trailing feature axis.
        return tuple(range(array.ndim - 1))

    # Means first, then standard deviations, for "a" and then "b" each time.
    for key in ("a", "b"):
        values = standardized[key]
        np.testing.assert_allclose(np.mean(values, axis=leading_axes(values)), 0.0, atol=1e-4)

    for key in ("a", "b"):
        values = standardized[key]
        np.testing.assert_allclose(np.std(values, axis=leading_axes(values)), 1.0, atol=1e-4)


def test_transformation_type_both_sides_scale():
# Fix a known covariance and mean in original (not standardized space)
covariance = np.array([[1, 0.5], [0.5, 2.0]], dtype="float32")
Expand Down