Skip to content

Commit

Permalink
Rename global_prior arg; more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
nspope committed Jan 5, 2024
1 parent 3b3b6cb commit 8f5ead8
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 21 deletions.
14 changes: 13 additions & 1 deletion tests/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,6 +439,18 @@ def test_custom_priors(self):
priors=grid,
)

def test_prior_mixture_dim(self):
    """
    Smoke test: date a small simulated tree sequence with the
    ``variational_gamma`` method, passing a gamma parameter grid as
    ``priors`` and requesting ``prior_mixture_dim=2``. Nothing is asserted
    about the result; the test passes if the call does not raise.
    """
    simulated_ts = msprime.simulate(
        8,
        mutation_rate=5,
        recombination_rate=5,
        random_seed=2,
    )
    # Build the node-specific shape/rate grid from a gamma mixture prior.
    mixture_prior = tsdate.prior.MixturePrior(
        simulated_ts, prior_distribution="gamma"
    )
    parameter_grid = mixture_prior.make_parameter_grid(population_size=1)
    dating_options = {
        "mutation_rate": 5,
        "method": "variational_gamma",
        "priors": parameter_grid,
        "prior_mixture_dim": 2,
    }
    tsdate.date(simulated_ts, **dating_options)

def test_bad_arguments(self):
ts = utility_functions.two_tree_mutation_ts()
with pytest.raises(ValueError, match="Maximum number of EP iterations"):
Expand All @@ -455,7 +467,7 @@ def test_bad_arguments(self):
mutation_rate=5,
population_size=1,
method="variational_gamma",
global_prior=False,
prior_mixture_dim=0.1,
)

def test_match_central_moments(self):
Expand Down
6 changes: 3 additions & 3 deletions tsdate/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,13 +213,13 @@ def tsdate_cli_parser():
type=int,
help=(
"The number of expectation-maximization iterations used to optimize the "
"global mixture prior at the end of each expectation propagation step. "
"i.i.d. mixture prior at the end of each expectation propagation step. "
"Setting to zero disables optimization. Default: 10"
),
default=10,
)
parser.add_argument(
"--global-prior",
"--prior-mixture-dim",
type=int,
help=(
"The number of components in the i.i.d. mixture prior for node "
Expand Down Expand Up @@ -285,7 +285,7 @@ def run_date(args):
max_shape=args.max_shape,
match_central_moments=args.match_central_moments,
em_iterations=args.em_iterations,
global_prior=args.global_prior,
prior_mixture_dim=args.prior_mixture_dim,
)
else:
params = dict(
Expand Down
37 changes: 20 additions & 17 deletions tsdate/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1030,8 +1030,9 @@ def propagate_likelihood(
min_kl,
):
"""
Update approximating factors for each edge, returning average relative
difference in natural parameters (TODO)
Update approximating factors for each edge.
TODO: return max difference in natural parameters for stopping criterion
:param ndarray edges: integer array of dimension `[num_edges, 3]`
containing edge id, parent id, and child id.
Expand Down Expand Up @@ -1120,7 +1121,8 @@ def posterior_damping(x):
def propagate_prior(
nodes, global_prior, posterior, messages, scale, max_shape, em_maxitt, em_reltol
):
"""TODO
"""
Update approximating factors for global prior at each node.
:param ndarray nodes: ids of nodes that should be updated
:param ndarray global_prior: rows are mixture components, columns are
Expand Down Expand Up @@ -1530,7 +1532,7 @@ def main_algorithm(self):
fixed_node_set=self.get_fixed_nodes_set(),
)
return ExpectationPropagation(
self.priors, lik, progress=self.pbar, global_prior=self.prior_mixture
self.priors, lik, progress=self.pbar, global_prior=self.global_prior
)

def run(
Expand All @@ -1539,7 +1541,7 @@ def run(
max_iterations,
max_shape,
match_central_moments,
global_prior,
prior_mixture_dim,
em_iterations,
):
if self.provenance_params is not None:
Expand All @@ -1548,13 +1550,13 @@ def run(
)
if not max_iterations >= 1:
raise ValueError("Maximum number of EP iterations must be greater than 0")
if not (isinstance(global_prior, int) and global_prior > 0):
raise ValueError("'global_prior' must be a positive integer")
if not (isinstance(prior_mixture_dim, int) and prior_mixture_dim > 0):
raise ValueError("Number of mixture components must be a positive integer")
if self.mutation_rate is None:
raise ValueError("Variational gamma method requires mutation rate")

self.prior_mixture = mixture.initialize_mixture(
self.priors.grid_data, global_prior
self.global_prior = mixture.initialize_mixture(
self.priors.grid_data, prior_mixture_dim
)
self.priors.grid_data[:] = [0.0, 0.0] # TODO: support node-specific priors

Expand Down Expand Up @@ -1788,7 +1790,7 @@ def variational_gamma(
max_iterations=None,
max_shape=None,
match_central_moments=None,
global_prior=1,
prior_mixture_dim=1,
em_iterations=10,
**kwargs,
):
Expand All @@ -1806,8 +1808,9 @@ def variational_gamma(
An i.i.d. gamma mixture is used as a prior for each node, that is
initialized from the conditional coalescent and updated via expectation
maximization in each iteration. In addition, node-specific priors may be
specified via a grid of shape/rate parameters.
maximization in each iteration. If node-specific priors are supplied
(via a grid of shape/rate parameters) then these are used for
initialization.
.. note::
The prior parameters for each node-to-be-dated take the form of a
Expand All @@ -1830,10 +1833,10 @@ def variational_gamma(
update matches mean and variance rather than expected gamma sufficient
statistics. Faster with a similar accuracy, but does not exactly minimize
Kullback-Leibler divergence. Default: None, treated as False.
:param int global_prior: The number of components in the i.i.d. mixture prior
:param int prior_mixture_dim: The number of components in the i.i.d. mixture prior
for node ages. Default: 1 (``None`` is also treated as 1).
:param int em_iterations: The number of expectation maximization iterations used
to optimize the global mixture prior. Setting to zero disables optimization.
to optimize the i.i.d. mixture prior. Setting to zero disables optimization.
Default: 10 (``None`` is also treated as 10).
:param \\**kwargs: Other keyword arguments as described in the :func:`date` wrapper
function, notably ``mutation_rate``, and ``population_size`` or ``priors``.
Expand Down Expand Up @@ -1866,8 +1869,8 @@ def variational_gamma(
max_shape = 1000
if match_central_moments is None:
match_central_moments = False
if global_prior is None:
global_prior = 1
if prior_mixture_dim is None:
prior_mixture_dim = 1
if em_iterations is None:
em_iterations = 10

Expand All @@ -1877,7 +1880,7 @@ def variational_gamma(
max_iterations=max_iterations,
max_shape=max_shape,
match_central_moments=match_central_moments,
global_prior=global_prior,
prior_mixture_dim=prior_mixture_dim,
em_iterations=em_iterations,
)
return dating_method.parse_result(result, eps, {"parameter": ["shape", "rate"]})
Expand Down

0 comments on commit 8f5ead8

Please sign in to comment.