From 8f5ead820560524924f4af2e46b53c2fe0503b5e Mon Sep 17 00:00:00 2001 From: Nate Pope Date: Fri, 5 Jan 2024 15:26:32 -0800 Subject: [PATCH] Rename global_prior arg; more tests --- tests/test_inference.py | 14 +++++++++++++- tsdate/cli.py | 6 +++--- tsdate/core.py | 37 ++++++++++++++++++++----------------- 3 files changed, 36 insertions(+), 21 deletions(-) diff --git a/tests/test_inference.py b/tests/test_inference.py index 1f30ebec..5f9182da 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -439,6 +439,18 @@ def test_custom_priors(self): priors=grid, ) + def test_prior_mixture_dim(self): + ts = msprime.simulate(8, mutation_rate=5, recombination_rate=5, random_seed=2) + priors = tsdate.prior.MixturePrior(ts, prior_distribution="gamma") + grid = priors.make_parameter_grid(population_size=1) + tsdate.date( + ts, + mutation_rate=5, + method="variational_gamma", + priors=grid, + prior_mixture_dim=2, + ) + def test_bad_arguments(self): ts = utility_functions.two_tree_mutation_ts() with pytest.raises(ValueError, match="Maximum number of EP iterations"): @@ -455,7 +467,7 @@ def test_bad_arguments(self): mutation_rate=5, population_size=1, method="variational_gamma", - global_prior=False, + prior_mixture_dim=0.1, ) def test_match_central_moments(self): diff --git a/tsdate/cli.py b/tsdate/cli.py index 54e89a96..6248f358 100644 --- a/tsdate/cli.py +++ b/tsdate/cli.py @@ -213,13 +213,13 @@ def tsdate_cli_parser(): type=int, help=( "The number of expectation-maximization iterations used to optimize the " - "global mixture prior at the end of each expectation propagation step. " + "i.i.d. mixture prior at the end of each expectation propagation step. " "Setting to zero disables optimization. Default: 10" ), default=10, ) parser.add_argument( - "--global-prior", + "--prior-mixture-dim", type=int, help=( "The number of components in the i.i.d. mixture prior for node " @@ -285,7 +285,7 @@ def run_date(args): max_shape=args.max_shape, match_central_moments=args.match_central_moments, em_iterations=args.em_iterations, - global_prior=args.global_prior, + prior_mixture_dim=args.prior_mixture_dim, ) else: params = dict( diff --git a/tsdate/core.py b/tsdate/core.py index dcf077b5..d1377f07 100644 --- a/tsdate/core.py +++ b/tsdate/core.py @@ -1030,8 +1030,9 @@ def propagate_likelihood( min_kl, ): """ - Update approximating factors for each edge, returning average relative - difference in natural parameters (TODO) + Update approximating factors for each edge. + + TODO: return max difference in natural parameters for stopping criterion :param ndarray edges: integer array of dimension `[num_edges, 3]` containing edge id, parent id, and child id. @@ -1120,7 +1121,8 @@ def posterior_damping(x): def propagate_prior( nodes, global_prior, posterior, messages, scale, max_shape, em_maxitt, em_reltol ): - """TODO + """ + Update approximating factors for global prior at each node. :param ndarray nodes: ids of nodes that should be updated :param ndarray global_prior: rows are mixture components, columns are @@ -1530,7 +1532,7 @@ def main_algorithm(self): fixed_node_set=self.get_fixed_nodes_set(), ) return ExpectationPropagation( - self.priors, lik, progress=self.pbar, global_prior=self.prior_mixture + self.priors, lik, progress=self.pbar, global_prior=self.global_prior ) def run( @@ -1539,7 +1541,7 @@ def run( max_iterations, max_shape, match_central_moments, - global_prior, + prior_mixture_dim, em_iterations, ): if self.provenance_params is not None: @@ -1548,13 +1550,13 @@ def run( ) if not max_iterations >= 1: raise ValueError("Maximum number of EP iterations must be greater than 0") - if not (isinstance(global_prior, int) and global_prior > 0): - raise ValueError("'global_prior' must be a positive integer") + if not (isinstance(prior_mixture_dim, int) and prior_mixture_dim > 0): + raise ValueError("Number of mixture components must be a positive integer") if self.mutation_rate is None: raise ValueError("Variational gamma method requires mutation rate") - self.prior_mixture = mixture.initialize_mixture( - self.priors.grid_data, global_prior + self.global_prior = mixture.initialize_mixture( + self.priors.grid_data, prior_mixture_dim ) self.priors.grid_data[:] = [0.0, 0.0] # TODO: support node-specific priors @@ -1788,7 +1790,7 @@ def variational_gamma( max_iterations=None, max_shape=None, match_central_moments=None, - global_prior=1, + prior_mixture_dim=1, em_iterations=10, **kwargs, ): @@ -1806,8 +1808,9 @@ def variational_gamma( An i.i.d. gamma mixture is used as a prior for each node, that is initialized from the conditional coalescent and updated via expectation - maximization in each iteration. In addition, node-specific priors may be - specified via a grid of shape/rate parameters. + maximization in each iteration. If node-specific priors are supplied + (via a grid of shape/rate parameters) then these are used for + initialization. .. note:: The prior parameters for each node-to-be-dated take the form of a @@ -1830,10 +1833,10 @@ def variational_gamma( update matches mean and variance rather than expected gamma sufficient statistics. Faster with a similar accuracy, but does not exactly minimize Kullback-Leibler divergence. Default: None, treated as False. - :param int global_prior: The number of components in the i.i.d. mixture prior + :param int prior_mixture_dim: The number of components in the i.i.d. mixture prior for node ages. Default: None, treated as 1. :param int em_iterations: The number of expectation maximization iterations used - to optimize the global mixture prior. Setting to zero disables optimization. + to optimize the i.i.d. mixture prior. Setting to zero disables optimization. Default: None, treated as 10. :param \\**kwargs: Other keyword arguments as described in the :func:`date` wrapper function, notably ``mutation_rate``, and ``population_size`` or ``priors``. @@ -1866,8 +1869,8 @@ def variational_gamma( max_shape = 1000 if match_central_moments is None: match_central_moments = False - if global_prior is None: - global_prior = 1 + if prior_mixture_dim is None: + prior_mixture_dim = 1 if em_iterations is None: em_iterations = 10 @@ -1877,7 +1880,7 @@ def variational_gamma( max_iterations=max_iterations, max_shape=max_shape, match_central_moments=match_central_moments, - global_prior=global_prior, + prior_mixture_dim=prior_mixture_dim, em_iterations=em_iterations, ) return dating_method.parse_result(result, eps, {"parameter": ["shape", "rate"]})