Skip to content

Commit 62675c3

Browse files
thegialeovpratz
andauthored
Implement Feature for Issue #379: MMD Hypothesis Test (#384)
- add `bootstrap_comparison` and `summary_space_comparison` to enable comparisons of two domains in the data space or the summary space via bootstrapping - add `.summaries()` function for easy access to summaries to `ContinuousApproximator` and `ModelComparisonApproximator` - add tests for the added functionality --------- Co-authored-by: Valentin Pratz <git@valentinpratz.de>
1 parent 5595eab commit 62675c3

File tree

11 files changed

+595
-3
lines changed

11 files changed

+595
-3
lines changed

bayesflow/approximators/continuous_approximator.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -400,6 +400,39 @@ def _sample(
400400
**filter_kwargs(kwargs, self.inference_network.sample),
401401
)
402402

403+
def summaries(self, data: Mapping[str, np.ndarray], **kwargs):
404+
"""
405+
Computes the summaries of given data.
406+
407+
The `data` dictionary is preprocessed using the `adapter` and passed through the summary network.
408+
409+
Parameters
410+
----------
411+
data : Mapping[str, np.ndarray]
412+
Dictionary of data as NumPy arrays.
413+
**kwargs : dict
414+
Additional keyword arguments for the adapter and the summary network.
415+
416+
Returns
417+
-------
418+
summaries : np.ndarray
419+
Log-probabilities of the distribution `p(inference_variables | inference_conditions, h(summary_conditions))`
420+
421+
Raises
422+
------
423+
ValueError
424+
If the approximator does not have a summary network, or the adapter does not produce the output required
425+
by the summary network.
426+
"""
427+
if self.summary_network is None:
428+
raise ValueError("A summary network is required to compute summeries.")
429+
data_adapted = self.adapter(data, strict=False, stage="inference", **kwargs)
430+
if "summary_variables" not in data_adapted or data_adapted["summary_variables"] is None:
431+
raise ValueError("Summary variables are required to compute summaries.")
432+
summary_variables = keras.ops.convert_to_tensor(data_adapted["summary_variables"])
433+
summaries = self.summary_network(summary_variables, **filter_kwargs(kwargs, self.summary_network.call))
434+
return summaries
435+
403436
def log_prob(self, data: Mapping[str, np.ndarray], **kwargs) -> np.ndarray | dict[str, np.ndarray]:
404437
"""
405438
Computes the log-probability of given data under the model. The `data` dictionary is preprocessed using the

bayesflow/approximators/model_comparison_approximator.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -345,3 +345,36 @@ def _predict(self, classifier_conditions: Tensor = None, summary_variables: Tens
345345
output = self.logits_projector(output)
346346

347347
return output
348+
349+
def summaries(self, data: Mapping[str, np.ndarray], **kwargs):
350+
"""
351+
Computes the summaries of given data.
352+
353+
The `data` dictionary is preprocessed using the `adapter` and passed through the summary network.
354+
355+
Parameters
356+
----------
357+
data : Mapping[str, np.ndarray]
358+
Dictionary of data as NumPy arrays.
359+
**kwargs : dict
360+
Additional keyword arguments for the adapter and the summary network.
361+
362+
Returns
363+
-------
364+
summaries : np.ndarray
365+
Log-probabilities of the distribution `p(inference_variables | inference_conditions, h(summary_conditions))`
366+
367+
Raises
368+
------
369+
ValueError
370+
If the approximator does not have a summary network, or the adapter does not produce the output required
371+
by the summary network.
372+
"""
373+
if self.summary_network is None:
374+
raise ValueError("A summary network is required to compute summaries.")
375+
data_adapted = self.adapter(data, strict=False, stage="inference", **kwargs)
376+
if "summary_variables" not in data_adapted or data_adapted["summary_variables"] is None:
377+
raise ValueError("Summary variables are required to compute summaries.")
378+
summary_variables = keras.ops.convert_to_tensor(data_adapted["summary_variables"])
379+
summaries = self.summary_network(summary_variables, **filter_kwargs(kwargs, self.summary_network.call))
380+
return summaries

bayesflow/diagnostics/__init__.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,13 @@
1-
"""
1+
r"""
22
A collection of plotting utilities and metrics for evaluating trained :py:class:`~bayesflow.workflows.Workflow`\ s.
33
"""
44

5-
from .metrics import root_mean_squared_error, calibration_error, posterior_contraction
5+
from .metrics import (
6+
bootstrap_comparison,
7+
calibration_error,
8+
posterior_contraction,
9+
summary_space_comparison,
10+
)
611

712
from .plots import (
813
calibration_ecdf,

bayesflow/diagnostics/metrics/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,4 @@
33
from .root_mean_squared_error import root_mean_squared_error
44
from .expected_calibration_error import expected_calibration_error
55
from .classifier_two_sample_test import classifier_two_sample_test
6+
from .model_misspecification import bootstrap_comparison, summary_space_comparison
Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
"""
2+
This module provides functions for computing distances between observation samples and reference samples with distance
3+
distributions within the reference samples for hypothesis testing.
4+
"""
5+
6+
from collections.abc import Mapping, Callable
7+
8+
import numpy as np
9+
from keras.ops import convert_to_numpy, convert_to_tensor
10+
11+
from bayesflow.approximators import ContinuousApproximator
12+
from bayesflow.metrics.functional import maximum_mean_discrepancy
13+
from bayesflow.types import Tensor
14+
15+
16+
def bootstrap_comparison(
17+
observed_samples: np.ndarray,
18+
reference_samples: np.ndarray,
19+
comparison_fn: Callable[[Tensor, Tensor], Tensor],
20+
num_null_samples: int = 100,
21+
) -> tuple[float, np.ndarray]:
22+
"""Computes the distance between observed and reference samples and generates a distribution of null sample
23+
distances by bootstrapping for hypothesis testing.
24+
25+
Parameters
26+
----------
27+
observed_samples : np.ndarray)
28+
Observed samples, shape (num_observed, ...).
29+
reference_samples : np.ndarray
30+
Reference samples, shape (num_reference, ...).
31+
comparison_fn : Callable[[Tensor, Tensor], Tensor]
32+
Function to compute the distance metric.
33+
num_null_samples : int
34+
Number of null samples to generate for hypothesis testing. Default is 100.
35+
36+
Returns
37+
-------
38+
distance_observed : float
39+
The distance value between observed and reference samples.
40+
distance_null : np.ndarray
41+
A distribution of distance values under the null hypothesis.
42+
43+
Raises
44+
------
45+
ValueError
46+
- If the number of number of observed samples exceeds the number of reference samples
47+
- If the shapes of observed and reference samples do not match on dimensions besides the first one.
48+
"""
49+
num_observed: int = observed_samples.shape[0]
50+
num_reference: int = reference_samples.shape[0]
51+
52+
if num_observed > num_reference:
53+
raise ValueError(
54+
f"Number of observed samples ({num_observed}) cannot exceed"
55+
f"the number of reference samples ({num_reference}) for bootstrapping."
56+
)
57+
if observed_samples.shape[1:] != reference_samples.shape[1:]:
58+
raise ValueError(
59+
f"Expected observed and reference samples to have the same shape, "
60+
f"but got {observed_samples.shape[1:]} != {reference_samples.shape[1:]}."
61+
)
62+
63+
observed_samples_tensor: Tensor = convert_to_tensor(observed_samples, dtype="float32")
64+
reference_samples_tensor: Tensor = convert_to_tensor(reference_samples, dtype="float32")
65+
66+
distance_null_samples: np.ndarray = np.zeros(num_null_samples, dtype=np.float64)
67+
for i in range(num_null_samples):
68+
bootstrap_idx: np.ndarray = np.random.randint(0, num_reference, size=num_observed)
69+
bootstrap_samples: np.ndarray = reference_samples[bootstrap_idx]
70+
bootstrap_samples_tensor: Tensor = convert_to_tensor(bootstrap_samples, dtype="float32")
71+
distance_null_samples[i] = convert_to_numpy(comparison_fn(bootstrap_samples_tensor, reference_samples_tensor))
72+
73+
distance_observed_tensor: Tensor = comparison_fn(
74+
observed_samples_tensor,
75+
reference_samples_tensor,
76+
)
77+
78+
distance_observed: float = float(convert_to_numpy(distance_observed_tensor))
79+
80+
return distance_observed, distance_null_samples
81+
82+
83+
def summary_space_comparison(
84+
observed_data: Mapping[str, np.ndarray],
85+
reference_data: Mapping[str, np.ndarray],
86+
approximator: ContinuousApproximator,
87+
num_null_samples: int = 100,
88+
comparison_fn: Callable = maximum_mean_discrepancy,
89+
**kwargs,
90+
) -> tuple[float, np.ndarray]:
91+
"""Computes the distance between observed and reference data in the summary space and
92+
generates a distribution of distance values under the null hypothesis to assess model misspecification.
93+
94+
By default, the Maximum Mean Discrepancy (MMD) is used as a distance function.
95+
96+
[1] M. Schmitt, P.-C. Bürkner, U. Köthe, and S. T. Radev, "Detecting model misspecification in amortized Bayesian
97+
inference with neural networks," arXiv e-prints, Dec. 2021, Art. no. arXiv:2112.08866.
98+
URL: https://arxiv.org/abs/2112.08866
99+
100+
Parameters
101+
----------
102+
observed_data : dict[str, np.ndarray]
103+
Dictionary of observed data as NumPy arrays, which will be preprocessed by the approximators adapter and passed
104+
through its summary network.
105+
reference_data : dict[str, np.ndarray]
106+
Dictionary of reference data as NumPy arrays, which will be preprocessed by the approximators adapter and passed
107+
through its summary network.
108+
approximator : ContinuousApproximator
109+
An instance of :py:class:`~bayesflow.approximators.ContinuousApproximator` used to compute summary statistics
110+
from the data.
111+
num_null_samples : int, optional
112+
Number of null samples to generate for hypothesis testing. Default is 100.
113+
comparison_fn : Callable, optional
114+
Distance function to compare the data in the summary space.
115+
**kwargs : dict
116+
Additional keyword arguments for the adapter and sampling process.
117+
118+
Returns
119+
-------
120+
distance_observed : float
121+
The MMD value between observed and reference summaries.
122+
distance_null : np.ndarray
123+
A distribution of MMD values under the null hypothesis.
124+
125+
Raises
126+
------
127+
ValueError
128+
If approximator is not an instance of ContinuousApproximator or does not have a summary network.
129+
"""
130+
131+
if not isinstance(approximator, ContinuousApproximator):
132+
raise ValueError("The approximator must be an instance of ContinuousApproximator.")
133+
134+
if not hasattr(approximator, "summary_network") or approximator.summary_network is None:
135+
comparison_fn_name = (
136+
"bayesflow.metrics.functional.maximum_mean_discrepancy"
137+
if comparison_fn is maximum_mean_discrepancy
138+
else comparison_fn.__name__
139+
)
140+
raise ValueError(
141+
"The approximator must have a summary network. If you have manually crafted summary "
142+
"statistics, or want to compare raw data and not summary statistics, please use the "
143+
f"`bootstrap_comparison` function with `comparison_fn={comparison_fn_name}` on the respective arrays."
144+
)
145+
observed_summaries = convert_to_numpy(approximator.summaries(observed_data))
146+
reference_summaries = convert_to_numpy(approximator.summaries(reference_data))
147+
148+
distance_observed, distance_null = bootstrap_comparison(
149+
observed_samples=observed_summaries,
150+
reference_samples=reference_summaries,
151+
comparison_fn=comparison_fn,
152+
num_null_samples=num_null_samples,
153+
)
154+
155+
return distance_observed, distance_null

tests/test_approximators/conftest.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,3 +163,34 @@ def validation_dataset(batch_size, adapter, simulator):
163163
num_batches = 2
164164
data = simulator.sample((num_batches * batch_size,))
165165
return OfflineDataset(data=data, adapter=adapter, batch_size=batch_size, workers=4, max_queue_size=num_batches)
166+
167+
168+
@pytest.fixture()
169+
def mean_std_summary_network():
170+
from tests.utils import MeanStdSummaryNetwork
171+
172+
return MeanStdSummaryNetwork()
173+
174+
175+
@pytest.fixture(params=["continuous_approximator", "point_approximator", "model_comparison_approximator"])
176+
def approximator_with_summaries(request):
177+
from bayesflow.adapters import Adapter
178+
179+
adapter = Adapter()
180+
match request.param:
181+
case "continuous_approximator":
182+
from bayesflow.approximators import ContinuousApproximator
183+
184+
return ContinuousApproximator(adapter=adapter, inference_network=None, summary_network=None)
185+
case "point_approximator":
186+
from bayesflow.approximators import PointApproximator
187+
188+
return PointApproximator(adapter=adapter, inference_network=None, summary_network=None)
189+
case "model_comparison_approximator":
190+
from bayesflow.approximators import ModelComparisonApproximator
191+
192+
return ModelComparisonApproximator(
193+
num_models=2, classifier_network=None, adapter=adapter, summary_network=None
194+
)
195+
case _:
196+
raise ValueError("Invalid param for approximator class.")
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
import pytest
2+
from tests.utils import assert_allclose
3+
import keras
4+
5+
6+
def test_valid_summaries(approximator_with_summaries, mean_std_summary_network, monkeypatch):
7+
monkeypatch.setattr(approximator_with_summaries, "summary_network", mean_std_summary_network)
8+
summaries = approximator_with_summaries.summaries({"summary_variables": keras.ops.ones((2, 3))})
9+
assert_allclose(summaries, keras.ops.stack([keras.ops.ones((2,)), keras.ops.zeros((2,))], axis=-1))
10+
11+
12+
def test_no_summary_network(approximator_with_summaries, monkeypatch):
13+
monkeypatch.setattr(approximator_with_summaries, "summary_network", None)
14+
15+
with pytest.raises(ValueError):
16+
approximator_with_summaries.summaries({"summary_variables": keras.ops.ones((2, 3))})
17+
18+
19+
def test_no_summary_variables(approximator_with_summaries, mean_std_summary_network, monkeypatch):
20+
monkeypatch.setattr(approximator_with_summaries, "summary_network", mean_std_summary_network)
21+
22+
with pytest.raises(ValueError):
23+
approximator_with_summaries.summaries({})

tests/test_diagnostics/conftest.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,3 +78,17 @@ def history():
7878
}
7979

8080
return h
81+
82+
83+
@pytest.fixture()
84+
def adapter():
85+
from bayesflow.adapters import Adapter
86+
87+
return Adapter.create_default("parameters").rename("observables", "summary_variables")
88+
89+
90+
@pytest.fixture()
91+
def summary_network():
92+
from tests.utils import MeanStdSummaryNetwork
93+
94+
return MeanStdSummaryNetwork()

0 commit comments

Comments
 (0)