From 3b3b6cb0d94ee2e482f8ccfdcdbae5f246b37a3c Mon Sep 17 00:00:00 2001 From: Nate Pope Date: Fri, 5 Jan 2024 15:11:11 -0800 Subject: [PATCH] Sorting out tests, prior parameterization --- tests/test_inference.py | 17 +++++++++++++++-- tsdate/core.py | 29 ++++++++++++++++++++++------- tsdate/mixture.py | 14 +++++++++----- 3 files changed, 46 insertions(+), 14 deletions(-) diff --git a/tests/test_inference.py b/tests/test_inference.py index 9616dde3..1f30ebec 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -414,12 +414,12 @@ def test_simple_sim_multi_tree(self): ts = msprime.simulate(8, mutation_rate=5, recombination_rate=5, random_seed=2) tsdate.date(ts, mutation_rate=5, population_size=1, method="variational_gamma") - def test_nonglobal_priors(self): + def test_invalid_priors(self): ts = msprime.simulate(8, mutation_rate=5, recombination_rate=5, random_seed=2) priors = tsdate.prior.MixturePrior(ts, prior_distribution="gamma") grid = priors.make_parameter_grid(population_size=1) grid.grid_data[:] = [1.0, 0.0] # noninformative prior - with pytest.raises(ValueError, match="not yet implemented"): + with pytest.raises(ValueError, match="Non-positive shape/rate"): tsdate.date( ts, mutation_rate=5, @@ -427,6 +427,18 @@ def test_nonglobal_priors(self): priors=grid, ) + def test_custom_priors(self): + ts = msprime.simulate(8, mutation_rate=5, recombination_rate=5, random_seed=2) + priors = tsdate.prior.MixturePrior(ts, prior_distribution="gamma") + grid = priors.make_parameter_grid(population_size=1) + grid.grid_data[:] += 1.0 + tsdate.date( + ts, + mutation_rate=5, + method="variational_gamma", + priors=grid, + ) + def test_bad_arguments(self): ts = utility_functions.two_tree_mutation_ts() with pytest.raises(ValueError, match="Maximum number of EP iterations"): @@ -441,6 +453,7 @@ def test_bad_arguments(self): tsdate.date( ts, mutation_rate=5, + population_size=1, method="variational_gamma", global_prior=False, ) diff --git a/tsdate/core.py b/tsdate/core.py index 9fb42c69..dcf077b5 100644 --- a/tsdate/core.py +++ b/tsdate/core.py @@ -1487,11 +1487,14 @@ class VariationalGammaMethod(EstimationMethod): def __init__(self, ts, **kwargs): super().__init__(ts, **kwargs) - # convert priors to natural parameterization and average + # convert priors to natural parameterization for n in self.priors.nonfixed_nodes: + if not np.all(self.priors[n] > 0.0): + raise ValueError( + f"Non-positive shape/rate parameters for node {n}: " + f"{self.priors[n]}" + ) self.priors[n][0] -= 1.0 - assert self.priors[n][0] > -1.0 - assert self.priors[n][1] >= 0.0 @staticmethod def mean_var(ts, posterior): @@ -1526,9 +1529,19 @@ def main_algorithm(self): self.recombination_rate, fixed_node_set=self.get_fixed_nodes_set(), ) - return ExpectationPropagation(self.priors, lik, progress=self.pbar, global_prior=self.prior_mixture) + return ExpectationPropagation( + self.priors, lik, progress=self.pbar, global_prior=self.prior_mixture + ) - def run(self, eps, max_iterations, max_shape, match_central_moments, global_prior, em_iterations): + def run( + self, + eps, + max_iterations, + max_shape, + match_central_moments, + global_prior, + em_iterations, + ): if self.provenance_params is not None: self.provenance_params.update( {k: v for k, v in locals().items() if k != "self"} @@ -1540,8 +1553,10 @@ def run(self, eps, max_iterations, max_shape, match_central_moments, global_prio if self.mutation_rate is None: raise ValueError("Variational gamma method requires mutation rate") - self.prior_mixture = mixture.initialize_mixture(self.priors.grid_data, global_prior) - self.priors.grid_data[:] = [0.0, 0.0] # TODO: support node-specific priors + self.prior_mixture = mixture.initialize_mixture( + self.priors.grid_data, global_prior + ) + self.priors.grid_data[:] = [0.0, 0.0] # TODO: support node-specific priors # match sufficient statistics or match central moments min_kl = not match_central_moments diff --git a/tsdate/mixture.py b/tsdate/mixture.py index 51538b9f..0eb99687 100644 --- a/tsdate/mixture.py +++ b/tsdate/mixture.py @@ -225,16 +225,20 @@ def fit_gamma_mixture(mixture, observations, max_iterations, tolerance, verbose) def initialize_mixture(parameters, num_components): - """initialize clusters by dividing nodes into equal groups""" + """ + Initialize clusters by dividing nodes into equal groups. + "parameters" are in natural parameterization (not shape/rate) + """ global_prior = np.empty((num_components, 3)) num_nodes = parameters.shape[0] - age_classes = np.tile(np.arange(num_components), num_nodes // num_components + 1)[ - :num_nodes - ] + age_classes = np.tile( + np.arange(num_components), + num_nodes // num_components + 1, + )[:num_nodes] for k in range(num_components): indices = np.equal(age_classes, k) alpha, beta = approx.average_gammas( - parameters[indices, 0] - 1.0, parameters[indices, 1] + parameters[indices, 0], parameters[indices, 1] ) global_prior[k] = [1.0 / num_components, alpha, beta] return global_prior