
Commit 55d51df

Breaking changes: Fix bugs regarding counts in standardization layer (#525)
* standardization: add test for multi-input values (failing)

  This test reveals two bugs in the standardization layer:
  - count is updated multiple times
  - batch_count is too small, as the sizes from reduce_axes have to be multiplied

* breaking: fix bugs regarding count in standardization layer

  Fixes #524

  This fixes the two bugs described in c4cc133:
  - count was accidentally updated multiple times, leading to wrong values
  - count was calculated wrongly, as only the batch size was used; correct is the product of all reduce dimensions. This led to wrong standard deviations.

  While the batch dimension is the same for all inputs, the size of the second dimension might vary. For this reason, we need to introduce an input-specific `count` variable. This breaks serialization.

* fix assert statement in test
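For context, the layer accumulates per-feature running moments and folds each training batch in with the standard parallel-variance merge (Chan et al.). Below is a minimal NumPy sketch of the corrected bookkeeping. It is illustrative only, not the repository code; the helper name `merge_moments` and the shapes are assumptions:

import numpy as np

def merge_moments(mean, m2, count, x):
    """Merge batch x into running per-feature (mean, M2, count)."""
    reduce_axes = tuple(range(x.ndim - 1))
    # Second fix: every reduced axis contributes samples, so a batch adds
    # prod(shape[:-1]) observations per feature, not just shape[0].
    batch_count = np.prod(x.shape[:-1])
    batch_mean = x.mean(axis=reduce_axes)
    batch_m2 = ((x - batch_mean) ** 2).sum(axis=reduce_axes)
    total = count + batch_count
    delta = batch_mean - mean
    new_mean = mean + delta * batch_count / total
    new_m2 = m2 + batch_m2 + delta**2 * count * batch_count / total
    return new_mean, new_m2, total

# First fix: each input keeps its own count, so updating one input's
# moments never advances the count of another input.
x1, x2 = np.random.randn(2, 3, 5), np.random.randn(3, 3, 5)
mean, m2, count = np.zeros(5), np.zeros(5), 0.0
for x in (x1, x2):
    mean, m2, count = merge_moments(mean, m2, count, x)

# The merged moments match the moments of the concatenated data.
flat = np.concatenate([x1, x2]).reshape(-1, 5)
assert np.allclose(mean, flat.mean(axis=0))
assert np.allclose(np.sqrt(m2 / count), flat.std(axis=0))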
1 parent 17540b1 commit 55d51df

File tree: 4 files changed, +42 -7 lines changed


bayesflow/networks/standardization/standardization.py

Lines changed: 5 additions & 5 deletions
@@ -40,7 +40,7 @@ def moving_std(self, index: int) -> Tensor:
         """
         return keras.ops.where(
             self.moving_m2[index] > 0,
-            keras.ops.sqrt(self.moving_m2[index] / self.count),
+            keras.ops.sqrt(self.moving_m2[index] / self.count[index]),
             1.0,
         )

@@ -53,7 +53,7 @@ def build(self, input_shape: Shape):
         self.moving_m2 = [
             self.add_weight(shape=(shape[-1],), initializer="zeros", trainable=False) for shape in flattened_shapes
         ]
-        self.count = self.add_weight(shape=(), initializer="zeros", trainable=False)
+        self.count = [self.add_weight(shape=(), initializer="zeros", trainable=False) for _ in flattened_shapes]

     def call(
         self,
@@ -150,7 +150,7 @@ def _update_moments(self, x: Tensor, index: int):
         """

         reduce_axes = tuple(range(x.ndim - 1))
-        batch_count = keras.ops.cast(keras.ops.shape(x)[0], self.count.dtype)
+        batch_count = keras.ops.cast(keras.ops.prod(keras.ops.shape(x)[:-1]), self.count[index].dtype)

         # Compute batch mean and M2 per feature
         batch_mean = keras.ops.mean(x, axis=reduce_axes)
@@ -159,7 +159,7 @@ def _update_moments(self, x: Tensor, index: int):
         # Read current totals
         mean = self.moving_mean[index]
         m2 = self.moving_m2[index]
-        count = self.count
+        count = self.count[index]

         total_count = count + batch_count
         delta = batch_mean - mean
@@ -169,4 +169,4 @@ def _update_moments(self, x: Tensor, index: int):

         self.moving_mean[index].assign(new_mean)
         self.moving_m2[index].assign(new_m2)
-        self.count.assign(total_count)
+        self.count[index].assign(total_count)
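To gauge the impact of the second bug: for an input of shape (batch, len, features), the old code divided M2 by batch instead of batch * len, which inflates every standard deviation by a factor of sqrt(len). A quick numeric check with assumed shapes (again a sketch, not repository code):

import numpy as np

x = np.random.randn(2, 3, 5)  # (batch, len, features)
m2 = ((x - x.mean(axis=(0, 1))) ** 2).sum(axis=(0, 1))

old_std = np.sqrt(m2 / x.shape[0])                 # buggy: divides by 2
new_std = np.sqrt(m2 / (x.shape[0] * x.shape[1]))  # fixed: divides by 6

assert np.allclose(new_std, x.reshape(-1, 5).std(axis=0))
assert np.allclose(old_std, new_std * np.sqrt(x.shape[1]))  # inflated by sqrt(3)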

tests/test_approximators/test_approximator_standardization/test_approximator_standardization.py

Lines changed: 2 additions & 1 deletion
@@ -8,7 +8,8 @@ def test_save_and_load(tmp_path, approximator, train_dataset, validation_dataset
     approximator.build(data_shapes)
     for layer in approximator.standardize_layers.values():
         assert layer.built
-        assert layer.count == 0
+        for count in layer.count:
+            assert count == 0.0
     approximator.compute_metrics(**train_dataset[0])

     keras.saving.save_model(approximator, tmp_path / "model.keras")

tests/test_approximators/test_build.py

Lines changed: 2 additions & 1 deletion
@@ -14,4 +14,5 @@ def test_build(approximator, simulator, batch_size, adapter):
     approximator.build(batch_shapes)
     for layer in approximator.standardize_layers.values():
         assert layer.built
-        assert layer.count == 0
+        for count in layer.count:
+            assert count == 0.0

tests/test_networks/test_standardization.py

Lines changed: 33 additions & 0 deletions
@@ -91,6 +91,39 @@ def test_nested_consistency_forward_inverse():
     np.testing.assert_allclose(random_input["b"], recovered["b"], atol=1e-4)


+def test_nested_accuracy_forward():
+    from bayesflow.utils import tree_concatenate
+
+    # create inputs for two training passes
+    random_input_a_1 = keras.random.normal((2, 3, 5))
+    random_input_b_1 = keras.random.normal((4, 3))
+    random_input_1 = {"a": random_input_a_1, "b": random_input_b_1}
+
+    random_input_a_2 = keras.random.normal((3, 3, 5))
+    random_input_b_2 = keras.random.normal((3, 3))
+    random_input_2 = {"a": random_input_a_2, "b": random_input_b_2}
+
+    # complete data for testing mean and std are 0 and 1
+    random_input = tree_concatenate([random_input_1, random_input_2], axis=0)
+
+    layer = Standardization()
+
+    _ = layer(random_input_1, stage="training", forward=True)
+    _ = layer(random_input_2, stage="training", forward=True)
+
+    standardized = layer(random_input, stage="inference", forward=True)
+    standardized = keras.tree.map_structure(keras.ops.convert_to_numpy, standardized)
+
+    np.testing.assert_allclose(
+        np.mean(standardized["a"], axis=tuple(range(standardized["a"].ndim - 1))), 0.0, atol=1e-4
+    )
+    np.testing.assert_allclose(
+        np.mean(standardized["b"], axis=tuple(range(standardized["b"].ndim - 1))), 0.0, atol=1e-4
+    )
+    np.testing.assert_allclose(np.std(standardized["a"], axis=tuple(range(standardized["a"].ndim - 1))), 1.0, atol=1e-4)
+    np.testing.assert_allclose(np.std(standardized["b"], axis=tuple(range(standardized["b"].ndim - 1))), 1.0, atol=1e-4)
+
+
 def test_transformation_type_both_sides_scale():
     # Fix a known covariance and mean in original (not standardized space)
     covariance = np.array([[1, 0.5], [0.5, 2.0]], dtype="float32")
