Skip to content

Commit

Permalink
Backport PR gammapy#5590: Fix missing errors after using select_neste…
Browse files Browse the repository at this point in the history
…d_models
  • Loading branch information
registerrier authored and meeseeksmachine committed Nov 25, 2024
1 parent d1dc8c8 commit 9e9698d
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 18 deletions.
31 changes: 19 additions & 12 deletions gammapy/modeling/selection.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Licensed under a 3-clause BSD style license - see LICENSE.rst
import numpy as np
from gammapy.modeling import Fit, Parameter
from gammapy.modeling import Fit, Parameter, Covariance
from gammapy.stats.utils import sigma_to_ts
from .fit import FitResult, OptimizeResult

Expand Down Expand Up @@ -72,9 +72,9 @@ def ts_known_bkg(self, datasets):
(for example considereing diffuse background or nearby source).
"""
stat = datasets.stat_sum()
object_cache, prev_pars = self._apply_null_hypothesis(datasets)
cache = self._apply_null_hypothesis(datasets)
stat_null = datasets.stat_sum()
self._restore_status(datasets, object_cache, prev_pars)
self._restore_status(datasets, cache)
return stat_null - stat

def ts_asimov(self, datasets):
Expand Down Expand Up @@ -121,7 +121,7 @@ def run(self, datasets, apply_selection=True):
fit_results = self.fit.run(datasets)
stat = datasets.stat_sum()

object_cache, prev_pars = self._apply_null_hypothesis(datasets)
cache = self._apply_null_hypothesis(datasets)

if len(datasets.models.parameters.free_parameters) > 0:
fit_results_null = self.fit.run(datasets)
Expand All @@ -143,31 +143,38 @@ def run(self, datasets, apply_selection=True):
ts = stat_null - stat
if not apply_selection or ts > self.ts_threshold:
# restore default model if preferred against null hypothesis or if selection is ignored
self._restore_status(datasets, object_cache, prev_pars)
self._restore_status(datasets, cache)
return dict(
ts=ts,
fit_results=fit_results,
fit_results_null=fit_results_null,
)

def _apply_null_hypothesis(self, datasets):
object_cache = [p.__dict__ for p in datasets.models.parameters]
prev_pars = [p.value for p in datasets.models.parameters]
cache = dict()
cache["object"] = [p.__dict__ for p in datasets.models.parameters]
cache["values"] = [p.value for p in datasets.models.parameters]
cache["error"] = [p.error for p in datasets.models.parameters]
for p, val in zip(self.parameters, self.null_values):
if isinstance(val, Parameter):
p.__dict__ = val.__dict__
else:
p.value = val
p.frozen = True
return object_cache, prev_pars
cache["covar"] = Covariance(
datasets.models.parameters, datasets.models.covariance.data
)
return cache

def _restore_status(self, datasets, object_cache, prev_pars):
"""Restore parameters to given cached cached objects and values"""
def _restore_status(self, datasets, cache):
"""Restore parameters to given cached objects and values"""
for p in self.parameters:
p.frozen = False
for kp, p in enumerate(datasets.models.parameters):
p.__dict__ = object_cache[kp]
p.value = prev_pars[kp]
p.__dict__ = cache["object"][kp]
p.value = cache["values"][kp]
p.error = cache["error"][kp]
datasets._covariance = cache["covar"]


def select_nested_models(
Expand Down
34 changes: 28 additions & 6 deletions gammapy/modeling/tests/test_selection.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
# Licensed under a 3-clause BSD style license - see LICENSE.rst
import numpy as np
import pytest
from numpy.testing import assert_allclose
from gammapy.modeling.fit import Fit
from gammapy.modeling.models import Models
from gammapy.modeling.selection import TestStatisticNested, select_nested_models
from gammapy.utils.testing import requires_data

Expand All @@ -12,18 +14,19 @@ def fermi_datasets():

filename = "$GAMMAPY_DATA/fermi-3fhl-crab/Fermi-LAT-3FHL_datasets.yaml"
filename_models = "$GAMMAPY_DATA/fermi-3fhl-crab/Fermi-LAT-3FHL_models.yaml"
return Datasets.read(filename=filename, filename_models=filename_models)
fermi_datasets = Datasets.read(filename=filename, filename_models=filename_models)
return fermi_datasets


@requires_data()
def test_test_statistic_detection(fermi_datasets):

model = fermi_datasets.models["Crab Nebula"]

results = select_nested_models(
fermi_datasets, [model.spectral_model.amplitude], [0]
)
assert_allclose(results["ts"], 20905.667798, rtol=1e-5)
assert fermi_datasets.models.parameters["amplitude"].error != 0.0

ts_eval = TestStatisticNested([model.spectral_model.amplitude], [0])
ts_known_bkg = ts_eval.ts_known_bkg(fermi_datasets)
Expand All @@ -48,14 +51,15 @@ def test_test_statistic_detection(fermi_datasets):

@requires_data()
def test_test_statistic_detection_other_frozen(fermi_datasets):

with fermi_datasets.models.restore_status():
fermi_datasets.models.freeze()
model = fermi_datasets.models["Crab Nebula"]
results = select_nested_models(
fermi_datasets, [model.spectral_model.amplitude], [0]
)
results["fit_results_null"].nfev == 0
assert fermi_datasets.models.parameters["amplitude"].error != 0.0

model.spectral_model.amplitude.value = 0
assert_allclose(
results["fit_results_null"].parameters.value,
Expand All @@ -65,12 +69,15 @@ def test_test_statistic_detection_other_frozen(fermi_datasets):

@requires_data()
def test_test_statistic_link(fermi_datasets):

# TODO: better test with simulated data ?
model = fermi_datasets.models["Crab Nebula"]

models = Models.read("$GAMMAPY_DATA/fermi-3fhl-crab/Fermi-LAT-3FHL_models.yaml")

model = models["Crab Nebula"]
model2 = model.copy(name="other")
model2.spectral_model.alpha.value = 2.4
fermi_datasets.models = fermi_datasets.models + [model2]

fermi_datasets.models = models + [model2]

fit = Fit()
minuit_opts = {"tol": 10, "strategy": 0}
Expand All @@ -82,5 +89,20 @@ def test_test_statistic_link(fermi_datasets):
)
results = ts_eval.run(fermi_datasets)

assert results["ts"] > ts_eval.ts_threshold
assert model2.spectral_model.alpha.value != model.spectral_model.alpha.value
assert model2.spectral_model.alpha.error != model.spectral_model.alpha.error
assert model2.spectral_model.alpha.error != 0

ts_eval = TestStatisticNested(
[model.spectral_model.alpha],
[model2.spectral_model.alpha],
fit=fit,
n_sigma=np.inf,
)
results = ts_eval.run(fermi_datasets)

assert results["ts"] < ts_eval.ts_threshold
assert_allclose(model2.spectral_model.alpha.value, model.spectral_model.alpha.value)
assert_allclose(model2.spectral_model.alpha.error, model.spectral_model.alpha.error)
assert model2.spectral_model.alpha.error != 0

0 comments on commit 9e9698d

Please sign in to comment.