class TestUnaryNodeCheck:
    """
    Check that unary nodes are detected by `tsdate.util.contains_unary_nodes`
    and that `tsdate.date` refuses a tree sequence that contains them.
    """

    @staticmethod
    def _simulate(**extra_args):
        """
        Simulate ancestry and overlay mutations using the fixed seed and
        parameters shared by both tests; `extra_args` are forwarded to
        `msprime.sim_ancestry`.
        """
        ancestry = msprime.sim_ancestry(
            10,
            population_size=1e4,
            recombination_rate=1e-8,
            sequence_length=1e6,
            random_seed=1,
            **extra_args,
        )
        return msprime.sim_mutations(ancestry, rate=1e-8, random_seed=1)

    @staticmethod
    def _check_detected_and_rejected(with_unary):
        """
        Assert that unary nodes are reported for `with_unary` but not for its
        simplified copy, and that dating `with_unary` raises a ValueError.
        """
        without_unary = with_unary.simplify()
        assert tsdate.util.contains_unary_nodes(with_unary)
        assert not tsdate.util.contains_unary_nodes(without_unary)
        with pytest.raises(ValueError, match="contains unary nodes"):
            tsdate.date(with_unary, mutation_rate=1e-8, method="variational_gamma")

    def test_inferred(self):
        # A tree sequence inferred by tsinfer, dated without simplifying first
        simulated = self._simulate()
        samples = tsinfer.SampleData.from_tree_sequence(simulated)
        self._check_detected_and_rejected(tsinfer.infer(samples))

    def test_simulated(self):
        # A simulation recorded with `record_full_arg=True`
        self._check_detected_and_rejected(self._simulate(record_full_arg=True))
@numba.njit(_b(_i1r, _f1r, _f1r, _i1r, _i1r, _f, _i))
def _contains_unary_nodes(
    edges_parent,
    edges_left,
    edges_right,
    indexes_insert,
    indexes_remove,
    sequence_length,
    num_nodes,
):
    """
    Sweep the edges from left to right, tracking how many child edges each
    node currently has, and return True as soon as a node whose count changed
    at a breakpoint is left with exactly one child. Return False if the sweep
    finishes without finding such a node.

    :param edges_parent: Parent node id of each edge.
    :param edges_left: Left coordinate of each edge.
    :param edges_right: Right coordinate of each edge.
    :param indexes_insert: Edge ids in the order edges enter the sweep.
    :param indexes_remove: Edge ids in the order edges leave the sweep.
    :param sequence_length: Coordinate at which the sweep ends.
    :param num_nodes: Number of nodes, used to size the child counter.
    """
    assert edges_parent.size == edges_left.size == edges_right.size
    assert indexes_insert.size == indexes_remove.size == edges_parent.size

    total = edges_parent.size
    child_count = np.zeros(num_nodes, dtype=np.int32)
    starts = edges_left[indexes_insert]
    ends = edges_right[indexes_remove]

    position = 0.0
    next_in = 0
    next_out = 0
    while next_in < total or next_out < total:
        # Only nodes whose child count changes here can newly become unary
        touched = set()

        # Drop the edges that end at this breakpoint
        while next_out < total and ends[next_out] == position:
            parent = edges_parent[indexes_remove[next_out]]
            child_count[parent] -= 1
            touched.add(parent)
            next_out += 1

        # Add the edges that start at this breakpoint
        while next_in < total and starts[next_in] == position:
            parent = edges_parent[indexes_insert[next_in]]
            child_count[parent] += 1
            touched.add(parent)
            next_in += 1

        # Inspect counts only after removals and insertions are both applied
        for parent in touched:
            if child_count[parent] == 1:
                return True

        # Advance to the nearest pending breakpoint, capped at the sequence end
        upcoming = sequence_length
        if next_out < total:
            upcoming = min(upcoming, ends[next_out])
        if next_in < total:
            upcoming = min(upcoming, starts[next_in])
        position = upcoming

    return False


def contains_unary_nodes(ts):
    """
    Return True if some node of the tree sequence `ts` has exactly one child
    over any part of its span, and False otherwise.
    """
    edge_columns = (ts.edges_parent, ts.edges_left, ts.edges_right)
    sweep_orders = (
        ts.indexes_edge_insertion_order,
        ts.indexes_edge_removal_order,
    )
    return _contains_unary_nodes(
        *edge_columns,
        *sweep_orders,
        ts.sequence_length,
        ts.num_nodes,
    )
edge_sampling_weight from .rescaling import mutational_timescale from .rescaling import piecewise_scale_posterior +from .util import contains_unary_nodes # columns for edge_factors @@ -154,6 +155,10 @@ def _check_valid_constraints(constraints, edges_parent, edges_child): @staticmethod def _check_valid_inputs(ts, likelihoods, constraints, mutations_edge): + if contains_unary_nodes(ts): + raise ValueError( + "Tree sequence contains unary nodes, simplify before dating" + ) if likelihoods.shape != (ts.num_edges, 2): raise ValueError("Edge likelihoods are the wrong shape") if constraints.shape != (ts.num_nodes, 2):