Commit 361fa45

Add NNPE adapter class (#488)

* Add NNPE adapter
* Add NNPE adapter tests
* Only apply NNPE during training
* Integrate stage differentiation into tests
* Improve test coverage
* Fix inverse and add to tests
* Adjust class name and add docstring to forward method
* Enable compatibility with #486 by adjusting scales automatically
* Add dimensionwise noise application
* Update exception handling
* Fix tests
1 parent e13c944 commit 361fa45

File tree: 4 files changed, +302 −0 lines changed

bayesflow/adapters/adapter.py

Lines changed: 38 additions & 0 deletions
@@ -18,6 +18,7 @@
     Keep,
     Log,
     MapTransform,
+    NNPE,
     NumpyTransform,
     OneHot,
     Rename,
@@ -699,6 +700,43 @@ def map_dtype(self, keys: str | Sequence[str], to_dtype: str):
         self.transforms.append(transform)
         return self

+    def nnpe(
+        self,
+        keys: str | Sequence[str],
+        *,
+        spike_scale: float | np.ndarray | None = None,
+        slab_scale: float | np.ndarray | None = None,
+        per_dimension: bool = True,
+        seed: int | None = None,
+    ):
+        """Append an :py:class:`~transforms.NNPE` transform to the adapter.
+
+        Parameters
+        ----------
+        keys : str or Sequence of str
+            The names of the variables to transform.
+        spike_scale : float or np.ndarray or None, default=None
+            The scale of the spike (Normal) distribution. Automatically determined if None.
+        slab_scale : float or np.ndarray or None, default=None
+            The scale of the slab (Cauchy) distribution. Automatically determined if None.
+        per_dimension : bool, default=True
+            If True, noise is applied per dimension of the last axis of the input data.
+            If False, noise is applied globally.
+        seed : int or None
+            The seed for the random number generator. If None, a random seed is used.
+        """
+        if isinstance(keys, str):
+            keys = [keys]
+
+        transform = MapTransform(
+            {
+                key: NNPE(spike_scale=spike_scale, slab_scale=slab_scale, per_dimension=per_dimension, seed=seed)
+                for key in keys
+            }
+        )
+        self.transforms.append(transform)
+        return self
+
     def one_hot(self, keys: str | Sequence[str], num_classes: int):
         """Append a :py:class:`~transforms.OneHot` transform to the adapter.

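For orientation, a minimal usage sketch of the new method; the key name "x" and the data shape are illustrative, and the call pattern mirrors the tests added below:

>>> import numpy as np
>>> from bayesflow.adapters import Adapter
>>> adapter = Adapter().nnpe("x", spike_scale=1.0, slab_scale=1.0, per_dimension=False, seed=42)
>>> data = {"x": np.random.standard_normal((32, 4))}
>>> noisy = adapter(data, stage="training")  # spike-and-slab noise is added
>>> clean = adapter(data, stage="inference")  # returned unchanged outside training
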
bayesflow/adapters/transforms/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -12,6 +12,7 @@
 from .keep import Keep
 from .log import Log
 from .map_transform import MapTransform
+from .nnpe import NNPE
 from .numpy_transform import NumpyTransform
 from .one_hot import OneHot
 from .rename import Rename

bayesflow/adapters/transforms/nnpe.py

Lines changed: 187 additions & 0 deletions
@@ -0,0 +1,187 @@
+import numpy as np
+
+from bayesflow.utils.serialization import serializable, serialize
+
+from .elementwise_transform import ElementwiseTransform
+
+
+@serializable("bayesflow.adapters")
+class NNPE(ElementwiseTransform):
+    """Implements noisy neural posterior estimation (NNPE) as described in [1], which adds noise following a
+    spike-and-slab distribution to the training data as a mild form of data augmentation to robustify against noisy
+    real-world data (see [1, 2] for benchmarks). Adds the options of automatic noise scale determination and
+    dimensionwise noise application to the original implementation in [1] to provide more flexibility in dealing
+    with unstandardized and heterogeneous data.
+
+    [1] Ward, D., Cannon, P., Beaumont, M., Fasiolo, M., & Schmon, S. (2022). Robust neural posterior estimation and
+    statistical model criticism. Advances in Neural Information Processing Systems, 35, 33845-33859.
+    [2] Elsemüller, L., Pratz, V., von Krause, M., Voss, A., Bürkner, P. C., & Radev, S. T. (2025). Does Unsupervised
+    Domain Adaptation Improve the Robustness of Amortized Bayesian Inference? A Systematic Evaluation. arXiv preprint
+    arXiv:2502.04949.
+
+    Parameters
+    ----------
+    spike_scale : float or np.ndarray or None, default=None
+        The scale of the spike (Normal) distribution. Automatically determined if None (see "Notes" section).
+        Expects a float if `per_dimension=False` or a 1D array of length `data.shape[-1]` if `per_dimension=True`.
+    slab_scale : float or np.ndarray or None, default=None
+        The scale of the slab (Cauchy) distribution. Automatically determined if None (see "Notes" section).
+        Expects a float if `per_dimension=False` or a 1D array of length `data.shape[-1]` if `per_dimension=True`.
+    per_dimension : bool, default=True
+        If True, noise is applied per dimension of the last axis of the input data. If False, noise is applied
+        globally. Thus, if per_dimension=True, any provided scales must be arrays with shape (n_dimensions,) and
+        automatic scale determination occurs separately per dimension. If per_dimension=False, provided scales must
+        be floats and automatic scale determination occurs globally. The original implementation in [1] uses global
+        application (i.e., per_dimension=False), whereas dimensionwise application is recommended if the data
+        dimensions are heterogeneous.
+    seed : int or None
+        The seed for the random number generator. If None, a random seed is used. Used instead of
+        np.random.Generator here to enable easy serialization.
+
+    Notes
+    -----
+    The spike-and-slab distribution consists of a mixture of a Normal distribution (spike) and a Cauchy distribution
+    (slab), which are applied based on a Bernoulli random variable with p=0.5.
+
+    The scales of the spike and slab distributions can be set manually, or they are automatically determined by
+    scaling the default scales of [1] (which expect standardized data) by the standard deviation of the input data.
+    For automatic determination, the standard deviation is computed either globally (if `per_dimension=False`) or
+    per dimension of the last axis of the input data (if `per_dimension=True`). Note that automatic scale
+    determination is applied batch-wise in the forward method, which means that the determined scales can vary
+    between batches due to varying standard deviations in the batch input data.
+
+    The original implementation in [1] can be recovered by applying the following settings on standardized data:
+    - `spike_scale=0.01`
+    - `slab_scale=0.25`
+    - `per_dimension=False`
+
+    Examples
+    --------
+    >>> adapter = bf.Adapter().nnpe(["x"])
+    """
+
+    DEFAULT_SPIKE = 0.01
+    DEFAULT_SLAB = 0.25
+
+    def __init__(
+        self,
+        *,
+        spike_scale: float | np.ndarray | None = None,
+        slab_scale: float | np.ndarray | None = None,
+        per_dimension: bool = True,
+        seed: int | None = None,
+    ):
+        super().__init__()
+        self.spike_scale = spike_scale
+        self.slab_scale = slab_scale
+        self.per_dimension = per_dimension
+        self.seed = seed
+        self.rng = np.random.default_rng(seed)
+
+    def _resolve_scale(
+        self,
+        name: str,
+        passed: float | np.ndarray | None,
+        default: float,
+        data: np.ndarray,
+    ) -> np.ndarray | float:
+        """
+        Determine the spike/slab scale:
+        - If `passed` is None: automatic determination via default * std(data) (per dimension or globally).
+        - Otherwise: validate and cast `passed` to the expected shape/type.
+
+        Parameters
+        ----------
+        name : str
+            Identifier for error messages (e.g., 'spike_scale' or 'slab_scale').
+        passed : float or np.ndarray or None
+            User-specified scale. If None, compute as default * std(data).
+            If self.per_dimension is True, this may be a 1D array of length data.shape[-1].
+        default : float
+            Default multiplier from [1] to apply to the standard deviation of the data.
+        data : np.ndarray
+            Data array to compute the standard deviation from.
+
+        Returns
+        -------
+        float or np.ndarray
+            The resolved scale, either as a scalar (if per_dimension=False) or a 1D array of length
+            data.shape[-1] (if per_dimension=True).
+        """
+
+        # Compute the std and the expected scale shape, dimensionwise or globally
+        if self.per_dimension:
+            axes = tuple(range(data.ndim - 1))
+            std = np.std(data, axis=axes)
+            expected_shape = (data.shape[-1],)
+        else:
+            std = np.std(data)
+            expected_shape = None
+
+        # If no scale is passed, determine the scale automatically given the dimensionwise or global std
+        if passed is None:
+            return default * std
+        # If a scale is passed, check that its shape matches the expected shape
+        else:
+            if self.per_dimension:
+                arr = np.asarray(passed, dtype=float)
+                if arr.shape != expected_shape or arr.ndim != 1:
+                    raise ValueError(f"{name}: expected array of shape {expected_shape}, got {arr.shape}")
+                return arr
+            else:
+                try:
+                    scalar = float(passed)
+                except TypeError:
+                    raise TypeError(f"{name}: expected a scalar convertible to float, got type {type(passed).__name__}")
+                except ValueError:
+                    raise ValueError(f"{name}: expected a scalar convertible to float, got value {passed!r}")
+                return scalar
+
+    def forward(self, data: np.ndarray, stage: str = "inference", **kwargs) -> np.ndarray:
+        """
+        Add spike-and-slab noise to `data` during training, using automatic scale determination if scales are not
+        provided (see the "Notes" section of the class docstring for details).
+
+        Parameters
+        ----------
+        data : np.ndarray
+            Input array to be perturbed.
+        stage : str, default='inference'
+            If 'training', noise is added; otherwise the data is returned unchanged.
+        **kwargs
+            Unused keyword arguments.
+
+        Returns
+        -------
+        np.ndarray
+            Noisy data when `stage` is 'training', otherwise the original input.
+        """
+        if stage != "training":
+            return data
+
+        # Check data validity
+        if not np.all(np.isfinite(data)):
+            raise ValueError("NNPE.forward: `data` contains NaN or infinite values.")
+
+        spike_scale = self._resolve_scale("spike_scale", self.spike_scale, self.DEFAULT_SPIKE, data)
+        slab_scale = self._resolve_scale("slab_scale", self.slab_scale, self.DEFAULT_SLAB, data)
+
+        # Apply spike-and-slab noise
+        mixture_mask = self.rng.binomial(n=1, p=0.5, size=data.shape).astype(bool)
+        noise_spike = self.rng.standard_normal(size=data.shape) * spike_scale
+        noise_slab = self.rng.standard_cauchy(size=data.shape) * slab_scale
+        noise = np.where(mixture_mask, noise_slab, noise_spike)
+        return data + noise
+
+    def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray:
+        """Non-invertible transform."""
+        return data
+
+    def get_config(self) -> dict:
+        return serialize(
+            {
+                "spike_scale": self.spike_scale,
+                "slab_scale": self.slab_scale,
+                "per_dimension": self.per_dimension,
+                "seed": self.seed,
+            }
+        )

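The noise model itself is compact enough to sketch in plain NumPy. The following mirrors the mixture logic of `forward` with the default scales from [1], which assume standardized data; the array shape and seed are illustrative:

import numpy as np

rng = np.random.default_rng(42)
x = rng.standard_normal((64, 3))  # standardized toy data

# A Bernoulli(p=0.5) draw selects, per entry, between spike (Normal) and slab (Cauchy) noise
mask = rng.binomial(n=1, p=0.5, size=x.shape).astype(bool)
noise_spike = rng.standard_normal(size=x.shape) * 0.01  # DEFAULT_SPIKE
noise_slab = rng.standard_cauchy(size=x.shape) * 0.25  # DEFAULT_SLAB
x_noisy = x + np.where(mask, noise_slab, noise_spike)
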
tests/test_adapters/test_adapters.py

Lines changed: 76 additions & 0 deletions
@@ -296,3 +296,79 @@ def test_log_det_jac_exceptions(random_data):

     # inverse works when concatenation is used after transforms
     assert np.allclose(forward_log_det_jac["p"], -inverse_log_det_jac)
+
+
+def test_nnpe(random_data):
+    # NNPE cannot be integrated into the adapter fixture and its tests, since it modifies the input data
+    # and therefore breaks the existing allclose checks
+    import numpy as np
+    from bayesflow.adapters import Adapter
+
+    # Test the basic case with global noise application
+    ad = Adapter().nnpe("x1", spike_scale=1.0, slab_scale=1.0, per_dimension=False, seed=42)
+    result_training = ad(random_data, stage="training")
+    result_validation = ad(random_data, stage="validation")
+    result_inference = ad(random_data, stage="inference")
+    result_inversed = ad(random_data, inverse=True)
+    serialized = serialize(ad)
+    deserialized = deserialize(serialized)
+    reserialized = serialize(deserialized)
+
+    assert keras.tree.lists_to_tuples(serialized) == keras.tree.lists_to_tuples(reserialized)
+
+    # Check that only x1 is changed
+    assert "x1" in result_training
+    assert not np.allclose(result_training["x1"], random_data["x1"])
+
+    # All other keys are untouched
+    for k, v in random_data.items():
+        if k == "x1":
+            continue
+        assert np.allclose(result_training[k], v)
+
+    # Check that the validation and inference data, as well as the inversed results, are unchanged
+    for k, v in random_data.items():
+        assert np.allclose(result_validation[k], v)
+        assert np.allclose(result_inference[k], v)
+        assert np.allclose(result_inversed[k], v)
+
+    # Test the case where both scales and the seed are None (automatic scale determination, dimensionwise)
+    ad_auto = Adapter().nnpe("y1", slab_scale=None, spike_scale=None, per_dimension=True, seed=None)
+    result_training_auto = ad_auto(random_data, stage="training")
+    assert not np.allclose(result_training_auto["y1"], random_data["y1"])
+    for k, v in random_data.items():
+        if k == "y1":
+            continue
+        assert np.allclose(result_training_auto[k], v)
+
+    serialized_auto = serialize(ad_auto)
+    deserialized_auto = deserialize(serialized_auto)
+    reserialized_auto = serialize(deserialized_auto)
+    assert keras.tree.lists_to_tuples(serialized_auto) == keras.tree.lists_to_tuples(reserialized_auto)
+
+    # Test dimensionwise versus global noise application (per_dimension=True vs. per_dimension=False)
+    # Create data whose second dimension has much higher variance than the (constant) first dimension
+    data_shape = (32, 16, 1)
+    rng = np.random.default_rng(42)
+    zero = np.ones(shape=data_shape)  # zero-variance dimension
+    high = rng.normal(0, 100.0, size=data_shape)  # high-variance dimension
+    var_data = {"x": np.concatenate([zero, high], axis=-1)}
+
+    # Apply dimensionwise and global adapters with automatic slab_scale determination
+    ad_partial_global = Adapter().nnpe("x", spike_scale=0, slab_scale=None, per_dimension=False, seed=42)
+    ad_partial_dim = Adapter().nnpe("x", spike_scale=[0, 1], slab_scale=None, per_dimension=True, seed=42)
+    res_dim = ad_partial_dim(var_data, stage="training")
+    res_glob = ad_partial_global(var_data, stage="training")
+
+    # Compute the standard deviation of the noise per dimension of the last axis
+    noise_dim = res_dim["x"] - var_data["x"]
+    noise_glob = res_glob["x"] - var_data["x"]
+    std_dim = np.std(noise_dim, axis=(0, 1))
+    std_glob = np.std(noise_glob, axis=(0, 1))
+
+    # Dimensionwise application should assign zero noise to the zero-variance dimension, global application some
+    assert std_dim[0] == 0
+    assert std_glob[0] > 0
+    # Both should assign noise to the high-variance dimension
+    assert std_dim[1] > 0
+    assert std_glob[1] > 0

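The dimensionwise behavior checked by the last test follows directly from the automatic scale rule `default * std`, where the standard deviation is taken over all axes except the last. A minimal sketch of that rule in isolation (toy data; values are illustrative):

import numpy as np

rng = np.random.default_rng(0)
# One constant dimension and one high-variance dimension, as in the test above
data = np.stack([np.ones(1000), rng.normal(0, 100.0, size=1000)], axis=-1)

std = np.std(data, axis=tuple(range(data.ndim - 1)))  # std per dimension of the last axis
spike_scale = 0.01 * std  # DEFAULT_SPIKE * std, roughly [0.0, 1.0]
slab_scale = 0.25 * std  # DEFAULT_SLAB * std, roughly [0.0, 25.0]
# The constant dimension resolves to zero scales, so dimensionwise noise leaves it untouched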