Skip to content

Commit 2a19d32

Browse files
rename log_gamma to calibration_log_gamma (#527)
1 parent 55d51df commit 2a19d32

File tree

3 files changed

+9
-8
lines changed

3 files changed

+9
-8
lines changed

bayesflow/diagnostics/metrics/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,4 @@
44
from .expected_calibration_error import expected_calibration_error
55
from .classifier_two_sample_test import classifier_two_sample_test
66
from .model_misspecification import bootstrap_comparison, summary_space_comparison
7-
from .sbc import log_gamma
7+
from .calibration_log_gamma import calibration_log_gamma, gamma_null_distribution, gamma_discrepancy

bayesflow/diagnostics/metrics/sbc.py renamed to bayesflow/diagnostics/metrics/calibration_log_gamma.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from ...utils.dict_utils import dicts_to_arrays
77

88

9-
def log_gamma(
9+
def calibration_log_gamma(
1010
estimates: Mapping[str, np.ndarray] | np.ndarray,
1111
targets: Mapping[str, np.ndarray] | np.ndarray,
1212
variable_keys: Sequence[str] = None,
@@ -15,7 +15,8 @@ def log_gamma(
1515
quantile: float = 0.05,
1616
):
1717
"""
18-
Compute the log gamma discrepancy statistic, see [1] for additional information.
18+
Compute the log gamma discrepancy statistic to test posterior calibration,
19+
see [1] for additional information.
1920
Log gamma is log(gamma/gamma_null), where gamma_null is the 5th percentile of the
2021
null distribution under uniformity of ranks.
2122
That is, if adopting a hypothesis testing framework, then log_gamma < 0 implies

tests/test_diagnostics/test_diagnostics_metrics.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -85,15 +85,15 @@ def test_expected_calibration_error(pred_models, true_models, model_names):
8585
out = bf.diagnostics.metrics.expected_calibration_error(pred_models, true_models.transpose)
8686

8787

88-
def test_log_gamma(random_estimates, random_targets):
89-
out = bf.diagnostics.metrics.log_gamma(random_estimates, random_targets)
88+
def test_calibration_log_gamma(random_estimates, random_targets):
89+
out = bf.diagnostics.metrics.calibration_log_gamma(random_estimates, random_targets)
9090
assert list(out.keys()) == ["values", "metric_name", "variable_names"]
9191
assert out["values"].shape == (num_variables(random_estimates),)
9292
assert out["metric_name"] == "Log Gamma"
9393
assert out["variable_names"] == ["beta_0", "beta_1", "sigma"]
9494

9595

96-
def test_log_gamma_end_to_end():
96+
def test_calibration_log_gamma_end_to_end():
9797
# This is a function test for simulation-based calibration.
9898
# First, we sample from a known generative process and then run SBC.
9999
# If the log gamma statistic is correctly implemented, a 95% interval should exclude
@@ -116,11 +116,11 @@ def run_sbc(N=N, S=S, D=D, bias=0):
116116
ranks = np.sum(posterior_draws < prior_draws, axis=0)
117117

118118
# this is the distribution of gamma under uniform ranks
119-
gamma_null = bf.diagnostics.metrics.sbc.gamma_null_distribution(D, S, num_null_draws=100)
119+
gamma_null = bf.diagnostics.metrics.gamma_null_distribution(D, S, num_null_draws=100)
120120
lower, upper = np.quantile(gamma_null, (0.05, 0.995))
121121

122122
# this is the empirical gamma
123-
observed_gamma = bf.diagnostics.metrics.sbc.gamma_discrepancy(ranks, num_post_draws=S)
123+
observed_gamma = bf.diagnostics.metrics.gamma_discrepancy(ranks, num_post_draws=S)
124124

125125
in_interval = lower <= observed_gamma < upper
126126

0 commit comments

Comments
 (0)