
Allow historical samples in inside_outside method #214


Open · wants to merge 1 commit into main
8 changes: 8 additions & 0 deletions CHANGELOG.rst
@@ -2,6 +2,14 @@
 [0.1.6] - ****-**-**
 --------------------
 
+**Features**
+
+- Historical samples can now be incorporated directly into the dating framework.
+  This is done by constructing a bespoke prior grid using
+  ``grid=tsdate.build_prior_grid(..., allow_historical_samples=True)`` and
+  passing that into ``tsdate.date``. It is also possible to set a variance for
+  historical sample nodes.
+
 **Breaking changes**
 
 - The standalone ``preprocess_ts`` function now defaults to not removing unreferenced
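For anyone wanting to try the feature, here is a minimal usage sketch (not part of the diff), assembled from the changelog entry above and the new ``test_non_contemporaneous`` test further down; the simulation parameters are illustrative only:

    import msprime
    import tsdate

    # Three contemporary samples plus one historical sample at time 1.0
    samples = [msprime.Sample(population=0, time=0)] * 3 + [
        msprime.Sample(population=0, time=1.0)
    ]
    ts = msprime.simulate(samples=samples, Ne=1, mutation_rate=2, random_seed=12)

    # Build a bespoke prior grid that permits samples at non-zero times,
    # then pass it to tsdate.date in place of the default priors
    priors = tsdate.build_prior_grid(ts, Ne=1, allow_historical_samples=True)
    dated_ts = tsdate.date(ts, priors=priors, mutation_rate=2)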
2 changes: 1 addition & 1 deletion requirements.txt
@@ -1,4 +1,4 @@
-tskit>=0.4.0
+tskit>=0.5.2
 tsinfer>=0.3.0
 flake8
 numpy
17 changes: 3 additions & 14 deletions tests/test_functions.py
@@ -1048,7 +1048,7 @@ def test_dangling_fails(self):
         print(ts.draw_text())
         print("Samples:", ts.samples())
         Ne = 0.5
-        with pytest.raises(ValueError, match="simplified"):
+        with pytest.raises(ValueError, match="simplify"):
             tsdate.build_prior_grid(ts, Ne, timepoints=np.array([0, 1.2, 2]))
         # mut_rate = 1
         # eps = 1e-6
@@ -1421,7 +1421,7 @@ def test_date_input(self):
 
     def test_sample_as_parent_fails(self):
         ts = utility_functions.single_tree_ts_n3_sample_as_parent()
-        with pytest.raises(NotImplementedError):
+        with pytest.raises(ValueError, match="samples at non-zero times"):
             tsdate.date(ts, mutation_rate=None, Ne=1)
 
     def test_recombination_not_implemented(self):
@@ -1532,18 +1532,7 @@ def test_constrain_ages_topo(self):
         ts = utility_functions.two_tree_ts()
         post_mn = np.array([0.0, 0.0, 0.0, 2.0, 1.0, 3.0])
         eps = 1e-6
-        nodes_to_date = np.array([3, 4, 5])
-        constrained_ages = constrain_ages_topo(ts, post_mn, eps, nodes_to_date)
-        assert np.array_equal(
-            np.array([0.0, 0.0, 0.0, 2.0, 2.000001, 3.0]), constrained_ages
-        )
-
-    def test_constrain_ages_topo_no_nodes_to_date(self):
-        ts = utility_functions.two_tree_ts()
-        post_mn = np.array([0.0, 0.0, 0.0, 2.0, 1.0, 3.0])
-        eps = 1e-6
-        nodes_to_date = None
-        constrained_ages = constrain_ages_topo(ts, post_mn, eps, nodes_to_date)
+        constrained_ages = constrain_ages_topo(ts, post_mn, eps)
         assert np.array_equal(
             np.array([0.0, 0.0, 0.0, 2.0, 2.000001, 3.0]), constrained_ages
         )
19 changes: 16 additions & 3 deletions tests/test_inference.py
@@ -61,7 +61,7 @@ def test_bad_Ne(self):
 
     def test_dangling_failure(self):
         ts = utility_functions.single_tree_ts_n2_dangling()
-        with pytest.raises(ValueError, match="simplified"):
+        with pytest.raises(ValueError, match="simplify"):
             tsdate.date(ts, mutation_rate=None, Ne=1)
 
     def test_unary_failure(self):
@@ -271,16 +271,29 @@ def test_fails_multi_root(self):
         with pytest.raises(ValueError):
             tsdate.date(multiroot_ts, Ne=1, mutation_rate=2, priors=good_priors)
 
-    def test_non_contemporaneous(self):
+    def test_non_contemporaneous_warn(self):
         samples = [
             msprime.Sample(population=0, time=0),
             msprime.Sample(population=0, time=0),
             msprime.Sample(population=0, time=0),
             msprime.Sample(population=0, time=1.0),
         ]
         ts = msprime.simulate(samples=samples, Ne=1, mutation_rate=2, random_seed=12)
-        with pytest.raises(NotImplementedError):
+        with pytest.raises(ValueError, match="samples at non-zero times"):
             tsdate.date(ts, Ne=1, mutation_rate=2)
+        with pytest.raises(ValueError, match="samples at non-zero times"):
+            tsdate.build_prior_grid(ts, Ne=1)
 
+    def test_non_contemporaneous(self):
+        samples = [
+            msprime.Sample(population=0, time=0),
+            msprime.Sample(population=0, time=0),
+            msprime.Sample(population=0, time=0),
+            msprime.Sample(population=0, time=1.0),
+        ]
+        ts = msprime.simulate(samples=samples, Ne=1, mutation_rate=2, random_seed=12)
+        priors = tsdate.build_prior_grid(ts, Ne=1, allow_historical_samples=True)
+        tsdate.date(ts, priors=priors, mutation_rate=2)
+
     def test_no_mutation_times(self):
         ts = msprime.simulate(20, Ne=1, mutation_rate=1, random_seed=12)
16 changes: 12 additions & 4 deletions tsdate/base.py
@@ -95,6 +95,12 @@ def __init__(
         ] = (-np.arange(num_nodes - self.num_nonfixed) - 1)
         self.probability_space = LIN
 
+    def fixed_node_ids(self):
+        return np.where(self.row_lookup < 0)[0]
+
+    def nonfixed_node_ids(self):
+        return np.where(self.row_lookup >= 0)[0]
+
     def force_probability_space(self, probability_space):
         """
         probability_space can be "logarithmic" or "linear": this function will force
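A note on the convention these two helpers rely on: ``row_lookup`` maps fixed nodes to negative values (encoding rows of ``fixed_data`` as ``-index - 1``, per the ``-np.arange(...) - 1`` assignment above) and non-fixed nodes to non-negative rows of the grid data. A standalone sketch with a hypothetical six-node lookup:

    import numpy as np

    # Nodes 0-2 fixed (negative entries encode fixed_data rows as -index - 1),
    # nodes 3-5 non-fixed (entries are rows 0-2 of the grid data)
    row_lookup = np.array([-1, -2, -3, 0, 1, 2])

    fixed_node_ids = np.where(row_lookup < 0)[0]      # array([0, 1, 2])
    nonfixed_node_ids = np.where(row_lookup >= 0)[0]  # array([3, 4, 5])
    fixed_rows = -row_lookup[fixed_node_ids] - 1      # array([0, 1, 2])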
@@ -140,6 +146,9 @@ def normalize(self):
         else:
             raise RuntimeError("Probability space is not", LIN, "or", LOG)
 
+    def is_fixed(self, node_id):
+        return self.row_lookup[node_id] < 0
+
     def __getitem__(self, node_id):
         index = self.row_lookup[node_id]
         if index < 0:
@@ -207,8 +216,7 @@ def fill_fixed(orig, fixed_data):
         new_obj.fixed_data = fill_fixed(
             self, grid_data if fixed_data is None else fixed_data
         )
-        if probability_space is None:
-            new_obj.probability_space = self.probability_space
-        else:
-            new_obj.probability_space = probability_space
+        new_obj.probability_space = self.probability_space
+        if probability_space is not None:
+            new_obj.force_probability_space(probability_space)
         return new_obj
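The replaced branch above silently relabelled cloned data: assigning ``probability_space`` directly changes the label without transforming the stored values, whereas ``force_probability_space`` converts them. A sketch of the distinction (a hypothetical helper, not the tsdate API), assuming linear and logarithmic spaces:

    import numpy as np

    def clone_space(values, current_space, requested_space=None):
        new_values = values.copy()
        space = current_space              # always start in the source space
        if requested_space is not None and requested_space != space:
            # the force_probability_space analogue: convert, don't just relabel
            if requested_space == "logarithmic":
                new_values = np.log(new_values)
            else:
                new_values = np.exp(new_values)
            space = requested_space
        return new_values, space

    vals, space = clone_space(np.array([0.5, 0.25]), "linear", "logarithmic")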
119 changes: 73 additions & 46 deletions tsdate/core.py
@@ -151,7 +151,7 @@ def _lik(muts, span, dt, mutation_rate, normalize=True):
         """
         ll = scipy.stats.poisson.pmf(muts, dt * mutation_rate * span)
         if normalize:
-            return ll / np.max(ll)
+            return ll / np.nanmax(ll)
         else:
             return ll
 
@@ -258,15 +258,28 @@ def get_mut_lik_fixed_node(self, edge):
 
         mutations_on_edge = self.mut_edges[edge.id]
         child_time = self.ts.node(edge.child).time
-        assert child_time == 0
-        # Temporary hack - we should really take a more precise likelihood
-        return self._lik(
-            mutations_on_edge,
-            edge.span,
-            self.timediff,
-            self.mut_rate,
-            normalize=self.normalize,
-        )
+        if child_time == 0:
+            return self._lik(
+                mutations_on_edge,
+                edge.span,
+                self.timediff,
+                self.mut_rate,
+                normalize=self.normalize,
+            )
+        else:
+            timediff = self.timepoints - child_time + 1e-8
+            # Temporary hack - we should really take a more precise likelihood
+            likelihood = self._lik(
+                mutations_on_edge,
+                edge.span,
+                timediff,
+                self.mut_rate,
+                normalize=self.normalize,
+            )
+            # Prevent child from being older than parent
+            likelihood[timediff < 0] = 0
+
+            return likelihood
 
     def get_mut_lik_lower_tri(self, edge):
         """
@@ -389,7 +402,7 @@ def get_fixed(self, arr, edge):
         return arr * liks
 
     def scale_geometric(self, fraction, value):
-        return value**fraction
+        return value ** fraction
 
 
 class LogLikelihoods(Likelihoods):
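The ``else`` branch added to ``get_mut_lik_fixed_node`` above is the usual Poisson edge likelihood evaluated on a grid shifted by the child's fixed age: timepoints older than the child give a positive branch length, and timepoints younger than the child are zeroed so a parent can never predate its child. A self-contained numeric sketch with made-up grid and edge values (``scipy.stats.poisson.pmf`` returns NaN for negative means, which is why ``_lik`` now normalizes with ``np.nanmax``):

    import numpy as np
    import scipy.stats

    timepoints = np.array([0.0, 1.0, 2.0, 4.0])  # hypothetical time grid
    child_time = 1.0                             # historical sample's fixed age
    muts, span, mut_rate = 3, 100.0, 0.01        # illustrative edge values

    timediff = timepoints - child_time + 1e-8    # branch length to each timepoint
    ll = scipy.stats.poisson.pmf(muts, timediff * mut_rate * span)
    ll = ll / np.nanmax(ll)                      # normalize, ignoring the NaNs
    ll[timediff < 0] = 0                         # parent cannot predate the child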
@@ -429,7 +442,7 @@ def _lik(muts, span, dt, mutation_rate, normalize=True):
         """
         ll = scipy.stats.poisson.logpmf(muts, dt * mutation_rate * span)
         if normalize:
-            return ll - np.max(ll)
+            return ll - np.nanmax(ll)
         else:
             return ll
 
@@ -634,11 +647,22 @@ def inside_pass(self, *, normalize=True, cache_inside=False, progress=None):
         inside = self.priors.clone_with_new_data(  # store inside matrix values
             grid_data=np.nan, fixed_data=self.lik.identity_constant
         )
+        # It is possible that a sample node is non-fixed, in which case we want to
+        # provide an inside array that reflects the prior distribution
+        nonfixed_samples = np.intersect1d(inside.nonfixed_node_ids(), self.ts.samples())
+        for u in nonfixed_samples:
+            # this is in the same probability space as the prior, so we should be
+            # OK just to copy the prior values straight in. It's unclear to me (Yan)
+            # how/if they should be normalised, however
+            inside[u][:] = self.priors[u]
+
         if cache_inside:
             g_i = np.full(
                 (self.ts.num_edges, self.lik.grid_size), self.lik.identity_constant
             )
         norm = np.full(self.ts.num_nodes, np.nan)
+        to_visit = np.zeros(self.ts.num_nodes, dtype=bool)
+        to_visit[inside.nonfixed_node_ids()] = True
         # Iterate through the nodes via groupby on parent node
         for parent, edges in tqdm(
             self.edges_by_parent_asc(),
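Two pieces of bookkeeping are added above: non-fixed sample nodes (historical samples being dated) enter the inside pass with a copy of their prior, and a ``to_visit`` mask records which non-fixed nodes still need a normalization constant once the parent loop has run. A minimal sketch of the mask logic with invented node numbers:

    import numpy as np

    num_nodes = 6
    samples = np.array([0, 1, 2, 3])    # node 3 is a historical sample
    nonfixed = np.array([3, 4, 5])      # nodes being dated, from the prior grid

    # Historical samples are the nodes that are both samples and non-fixed
    nonfixed_samples = np.intersect1d(nonfixed, samples)  # array([3])

    to_visit = np.zeros(num_nodes, dtype=bool)
    to_visit[nonfixed] = True           # every non-fixed node starts unvisited
    to_visit[[4, 5]] = False            # say nodes 4 and 5 were visited as parents
    leftover = np.where(to_visit)[0]    # array([3]): gets the identity constant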
@@ -673,14 +697,23 @@ def inside_pass(self, *, normalize=True, cache_inside=False, progress=None):
                         "dangling nodes: please simplify it"
                     )
                 daughter_val = self.lik.scale_geometric(
-                    spanfrac, self.lik.make_lower_tri(inside[edge.child])
+                    spanfrac, self.lik.make_lower_tri(inside_values)
                 )
                 edge_lik = self.lik.get_inside(daughter_val, edge)
                 val = self.lik.combine(val, edge_lik)
+                if np.all(val == 0):
+                    raise ValueError
                 if cache_inside:
                     g_i[edge.id] = edge_lik
-            norm[parent] = np.max(val) if normalize else 1
+            norm[parent] = np.max(val) if normalize else self.lik.identity_constant
             inside[parent] = self.lik.reduce(val, norm[parent])
+            to_visit[parent] = False
+
+        # There may be nodes that are not parents but are also not fixed (e.g.
+        # undated sample nodes). These need an identity normalization constant
+        for unfixed_unvisited in np.where(to_visit)[0]:
+            norm[unfixed_unvisited] = self.lik.identity_constant
+
         if cache_inside:
             self.g_i = self.lik.reduce(g_i, norm[self.ts.tables.edges.child, None])
         # Keep the results in this object
@@ -897,34 +930,32 @@ def posterior_mean_var(ts, posterior, *, fixed_node_set=None):
     return ts, mn_post, vr_post
 
 
-def constrain_ages_topo(ts, post_mn, eps, nodes_to_date=None, progress=False):
+def constrain_ages_topo(ts, node_times, eps, progress=False):
     """
-    If predicted node times violate topology, restrict node ages so that they
-    must be older than all their children.
+    If node_times violate topology, return increased node_times so that each node is
+    guaranteed to be older than any of its children.
     """
-    new_mn_post = np.copy(post_mn)
-    if nodes_to_date is None:
-        nodes_to_date = np.arange(ts.num_nodes, dtype=np.uint64)
-        nodes_to_date = nodes_to_date[~np.isin(nodes_to_date, ts.samples())]
-
-    tables = ts.tables
-    parents = tables.edges.parent
-    nd_children = tables.edges.child[np.argsort(parents)]
-    parents = sorted(parents)
-    parents_unique = np.unique(parents, return_index=True)
-    parent_indices = parents_unique[1][np.isin(parents_unique[0], nodes_to_date)]
-    for index, nd in tqdm(
-        enumerate(sorted(nodes_to_date)), desc="Constrain Ages", disable=not progress
+    edges_parent = ts.edges_parent
+    edges_child = ts.edges_child
+
+    new_node_times = np.copy(node_times)
+    # Traverse through the ARG, ensuring children come before parents.
+    # This can be done by iterating over groups of edges with the same parent
+    new_parent_edge_idx = np.where(np.diff(edges_parent) != 0)[0] + 1
+    for edges_start, edges_end in tqdm(
+        zip(
+            itertools.chain([0], new_parent_edge_idx),
+            itertools.chain(new_parent_edge_idx, [len(edges_parent)]),
+        ),
+        desc="Constrain Ages",
+        disable=not progress,
     ):
-        if index + 1 != len(nodes_to_date):
-            children_index = np.arange(parent_indices[index], parent_indices[index + 1])
-        else:
-            children_index = np.arange(parent_indices[index], ts.num_edges)
-        children = nd_children[children_index]
-        time = np.max(new_mn_post[children])
-        if new_mn_post[nd] <= time:
-            new_mn_post[nd] = time + eps
-    return new_mn_post
+        parent = edges_parent[edges_start]
+        child_ids = edges_child[edges_start:edges_end]  # May contain dups
+        oldest_child_time = np.max(new_node_times[child_ids])
+        if oldest_child_time >= new_node_times[parent]:
+            new_node_times[parent] = oldest_child_time + eps
+    return new_node_times
 
 
 def date(
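The rewritten ``constrain_ages_topo`` above leans on the tskit edge-sorting requirement: edges are grouped contiguously by parent, with parents appearing in child-before-parent time order, so group boundaries fall wherever ``np.diff(edges_parent)`` is non-zero. A standalone sketch of that grouping trick on a toy edge table:

    import itertools
    import numpy as np

    # Toy sorted edge arrays: parent ids contiguous, one child per edge
    edges_parent = np.array([3, 3, 4, 5, 5, 5])
    edges_child = np.array([0, 1, 2, 3, 4, 1])

    boundaries = np.where(np.diff(edges_parent) != 0)[0] + 1  # array([2, 3])
    starts = itertools.chain([0], boundaries)
    ends = itertools.chain(boundaries, [len(edges_parent)])
    for s, e in zip(starts, ends):
        print(edges_parent[s], edges_child[s:e])  # 3 [0 1], then 4 [2], then 5 [3 4 1]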
Expand Down Expand Up @@ -1015,7 +1046,7 @@ def date(
progress=progress,
**kwargs
)
constrained = constrain_ages_topo(tree_sequence, dates, eps, nds, progress)
constrained = constrain_ages_topo(tree_sequence, dates, eps, progress)
tables = tree_sequence.dump_tables()
tables.time_units = time_units
tables.nodes.time = constrained
@@ -1064,12 +1095,6 @@ def get_dates(
 
     :return: tuple(mn_post, posterior, timepoints, eps, nodes_to_date)
     """
-    # Stuff yet to be implemented. These can be deleted once fixed
-    for sample in tree_sequence.samples():
-        if tree_sequence.node(sample).time != 0:
-            raise NotImplementedError("Samples must all be at time 0")
-    fixed_nodes = set(tree_sequence.samples())
-
     # Default to not creating approximate priors unless ts has > 1000 samples
     approx_priors = False
     if tree_sequence.num_samples > 1000:
@@ -1097,6 +1122,8 @@
     )
     priors = priors
 
+    fixed_nodes = set(priors.fixed_node_ids())
+
     if probability_space != base.LOG:
         liklhd = Likelihoods(
             tree_sequence,