Skip to content

Commit

Permalink
Add check for unary nodes
Browse files Browse the repository at this point in the history
  • Loading branch information
nspope committed Jun 7, 2024
1 parent 575e194 commit 4500dcf
Show file tree
Hide file tree
Showing 3 changed files with 108 additions and 0 deletions.
35 changes: 35 additions & 0 deletions tests/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,3 +312,38 @@ def test_sim_example(self):
assert ts.num_trees == num_trees + first_empty + last_empty

# TODO - test minimum_gap param


class TestUnaryNodeCheck:
    """
    Tests that unary nodes are detected by ``tsdate.util.contains_unary_nodes``
    and that ``tsdate.date`` raises on tree sequences containing them.
    """

    def test_inferred(self):
        # An unsimplified tsinfer result is expected to carry unary nodes,
        # which simplification should then remove.
        ancestry = msprime.sim_ancestry(
            10,
            population_size=1e4,
            recombination_rate=1e-8,
            sequence_length=1e6,
            random_seed=1,
        )
        mutated = msprime.sim_mutations(ancestry, rate=1e-8, random_seed=1)
        samples = tsinfer.SampleData.from_tree_sequence(mutated)
        with_unary = tsinfer.infer(samples)
        without_unary = with_unary.simplify()

        assert tsdate.util.contains_unary_nodes(with_unary)
        assert not tsdate.util.contains_unary_nodes(without_unary)

        with pytest.raises(ValueError, match="contains unary nodes"):
            tsdate.date(with_unary, mutation_rate=1e-8, method="variational_gamma")

    def test_simulated(self):
        # A simulation recorded with ``record_full_arg=True`` is expected to
        # carry unary nodes until it is simplified.
        full_arg = msprime.sim_ancestry(
            10,
            population_size=1e4,
            recombination_rate=1e-8,
            sequence_length=1e6,
            random_seed=1,
            record_full_arg=True,
        )
        with_unary = msprime.sim_mutations(full_arg, rate=1e-8, random_seed=1)
        without_unary = with_unary.simplify()

        assert tsdate.util.contains_unary_nodes(with_unary)
        assert not tsdate.util.contains_unary_nodes(without_unary)

        with pytest.raises(ValueError, match="contains unary nodes"):
            tsdate.date(with_unary, mutation_rate=1e-8, method="variational_gamma")
68 changes: 68 additions & 0 deletions tsdate/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@

import tsdate
from . import provenance
from .approx import _b
from .approx import _b1r
from .approx import _f
from .approx import _f1r
Expand Down Expand Up @@ -706,3 +707,70 @@ def constrain_mutations(ts, nodes_time, mutations_edge):
logging.info(f"Set ages of {external} nonsegregating mutations to root times.")

return constrained_time


@numba.njit(_b(_i1r, _f1r, _f1r, _i1r, _i1r, _f, _i))
def _contains_unary_nodes(
    edges_parent,
    edges_left,
    edges_right,
    indexes_insert,
    indexes_remove,
    sequence_length,
    num_nodes,
):
    """
    Sweep over edge breakpoints from left to right, tracking the number of
    child edges attached to each node, and return True as soon as some node
    has exactly one child after all edge changes at a breakpoint are applied.

    :param edges_parent: Parent node id of each edge.
    :param edges_left: Left coordinate of each edge.
    :param edges_right: Right coordinate of each edge.
    :param indexes_insert: Edge ids in the order they are inserted.
    :param indexes_remove: Edge ids in the order they are removed.
    :param sequence_length: Upper bound on the sweep position.
    :param num_nodes: Number of nodes (size of the child-count array).
    :return: True if a node with a single child is found, False otherwise.
    """
    assert edges_parent.size == edges_left.size
    assert edges_left.size == edges_right.size
    assert indexes_insert.size == indexes_remove.size
    assert indexes_remove.size == edges_parent.size

    total_edges = edges_parent.size
    child_count = np.zeros(num_nodes, dtype=np.int32)
    insert_at = edges_left[indexes_insert]
    remove_at = edges_right[indexes_remove]

    position = 0.0
    next_in = 0
    next_out = 0
    while next_in < total_edges or next_out < total_edges:
        # Parents whose child count changed at this breakpoint
        touched = set()

        # Drop edges that end at the current position
        while next_out < total_edges and remove_at[next_out] == position:
            parent = edges_parent[indexes_remove[next_out]]
            child_count[parent] -= 1
            touched.add(parent)
            next_out += 1

        # Add edges that begin at the current position
        while next_in < total_edges and insert_at[next_in] == position:
            parent = edges_parent[indexes_insert[next_in]]
            child_count[parent] += 1
            touched.add(parent)
            next_in += 1

        # Only inspect counts once removals and insertions are both applied,
        # so that an edge swapped for another does not register as unary.
        for parent in touched:
            if child_count[parent] == 1:
                return True

        # Advance to the nearest pending breakpoint, capped at sequence_length
        upcoming = sequence_length
        if next_out < total_edges and remove_at[next_out] < upcoming:
            upcoming = remove_at[next_out]
        if next_in < total_edges and insert_at[next_in] < upcoming:
            upcoming = insert_at[next_in]
        position = upcoming

    return False


def contains_unary_nodes(ts):
    """
    Check if any node in the tree sequence is unary (has exactly one child)
    over some portion of its span.

    :param ts: The tree sequence to check; its edge columns, edge
        insertion/removal indexes, sequence length and node count are read.
    :return: The result of the compiled check ``_contains_unary_nodes``.
    """
    edge_columns = (ts.edges_parent, ts.edges_left, ts.edges_right)
    edge_orders = (
        ts.indexes_edge_insertion_order,
        ts.indexes_edge_removal_order,
    )
    return _contains_unary_nodes(
        *edge_columns,
        *edge_orders,
        ts.sequence_length,
        ts.num_nodes,
    )
5 changes: 5 additions & 0 deletions tsdate/variational.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
from .rescaling import edge_sampling_weight
from .rescaling import mutational_timescale
from .rescaling import piecewise_scale_posterior
from .util import contains_unary_nodes


# columns for edge_factors
Expand Down Expand Up @@ -154,6 +155,10 @@ def _check_valid_constraints(constraints, edges_parent, edges_child):

@staticmethod
def _check_valid_inputs(ts, likelihoods, constraints, mutations_edge):
if contains_unary_nodes(ts):
raise ValueError(
"Tree sequence contains unary nodes, simplify before dating"
)
if likelihoods.shape != (ts.num_edges, 2):
raise ValueError("Edge likelihoods are the wrong shape")
if constraints.shape != (ts.num_nodes, 2):
Expand Down

0 comments on commit 4500dcf

Please sign in to comment.