Skip to content

Commit

Permalink
Sorting out tests, prior parameterization
Browse files (browse the repository at this point in the history)
  • Loading branch information
nspope committed Jan 5, 2024
1 parent 6c1d07b commit 3b3b6cb
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 14 deletions.
17 changes: 15 additions & 2 deletions tests/test_inference.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -414,19 +414,31 @@ def test_simple_sim_multi_tree(self):
ts = msprime.simulate(8, mutation_rate=5, recombination_rate=5, random_seed=2)
tsdate.date(ts, mutation_rate=5, population_size=1, method="variational_gamma")

def test_nonglobal_priors(self):
def test_invalid_priors(self):
ts = msprime.simulate(8, mutation_rate=5, recombination_rate=5, random_seed=2)
priors = tsdate.prior.MixturePrior(ts, prior_distribution="gamma")
grid = priors.make_parameter_grid(population_size=1)
grid.grid_data[:] = [1.0, 0.0] # noninformative prior
with pytest.raises(ValueError, match="not yet implemented"):
with pytest.raises(ValueError, match="Non-positive shape/rate"):
tsdate.date(
ts,
mutation_rate=5,
method="variational_gamma",
priors=grid,
)

def test_custom_priors(self):
ts = msprime.simulate(8, mutation_rate=5, recombination_rate=5, random_seed=2)
priors = tsdate.prior.MixturePrior(ts, prior_distribution="gamma")
grid = priors.make_parameter_grid(population_size=1)
grid.grid_data[:] += 1.0
tsdate.date(
ts,
mutation_rate=5,
method="variational_gamma",
priors=grid,
)

def test_bad_arguments(self):
ts = utility_functions.two_tree_mutation_ts()
with pytest.raises(ValueError, match="Maximum number of EP iterations"):
Expand All @@ -441,6 +453,7 @@ def test_bad_arguments(self):
tsdate.date(
ts,
mutation_rate=5,
population_size=1,
method="variational_gamma",
global_prior=False,
)
Expand Down
29 changes: 22 additions & 7 deletions tsdate/core.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -1487,11 +1487,14 @@ class VariationalGammaMethod(EstimationMethod):

def __init__(self, ts, **kwargs):
super().__init__(ts, **kwargs)
# convert priors to natural parameterization and average
# convert priors to natural parameterization
for n in self.priors.nonfixed_nodes:
if not np.all(self.priors[n] > 0.0):
raise ValueError(
f"Non-positive shape/rate parameters for node {n}: "
f"{self.priors[n]}"
)
self.priors[n][0] -= 1.0
assert self.priors[n][0] > -1.0
assert self.priors[n][1] >= 0.0

@staticmethod
def mean_var(ts, posterior):
Expand Down Expand Up @@ -1526,9 +1529,19 @@ def main_algorithm(self):
self.recombination_rate,
fixed_node_set=self.get_fixed_nodes_set(),
)
return ExpectationPropagation(self.priors, lik, progress=self.pbar, global_prior=self.prior_mixture)
return ExpectationPropagation(
self.priors, lik, progress=self.pbar, global_prior=self.prior_mixture
)

def run(self, eps, max_iterations, max_shape, match_central_moments, global_prior, em_iterations):
def run(
self,
eps,
max_iterations,
max_shape,
match_central_moments,
global_prior,
em_iterations,
):
if self.provenance_params is not None:
self.provenance_params.update(
{k: v for k, v in locals().items() if k != "self"}
Expand All @@ -1540,8 +1553,10 @@ def run(self, eps, max_iterations, max_shape, match_central_moments, global_prio
if self.mutation_rate is None:
raise ValueError("Variational gamma method requires mutation rate")

self.prior_mixture = mixture.initialize_mixture(self.priors.grid_data, global_prior)
self.priors.grid_data[:] = [0.0, 0.0] # TODO: support node-specific priors
self.prior_mixture = mixture.initialize_mixture(
self.priors.grid_data, global_prior
)
self.priors.grid_data[:] = [0.0, 0.0] # TODO: support node-specific priors

# match sufficient statistics or match central moments
min_kl = not match_central_moments
Expand Down
14 changes: 9 additions & 5 deletions tsdate/mixture.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -225,16 +225,20 @@ def fit_gamma_mixture(mixture, observations, max_iterations, tolerance, verbose)


def initialize_mixture(parameters, num_components):
"""initialize clusters by dividing nodes into equal groups"""
"""
Initialize clusters by dividing nodes into equal groups.
"parameters" are in natural parameterization (not shape/rate)
"""
global_prior = np.empty((num_components, 3))
num_nodes = parameters.shape[0]
age_classes = np.tile(np.arange(num_components), num_nodes // num_components + 1)[
:num_nodes
]
age_classes = np.tile(
np.arange(num_components),
num_nodes // num_components + 1,
)[:num_nodes]
for k in range(num_components):
indices = np.equal(age_classes, k)
alpha, beta = approx.average_gammas(
parameters[indices, 0] - 1.0, parameters[indices, 1]
parameters[indices, 0], parameters[indices, 1]
)
global_prior[k] = [1.0 / num_components, alpha, beta]
return global_prior

0 comments on commit 3b3b6cb

Please sign in to comment.