Skip to content

Fixup num_tracked_samples to work with virtual_root #1861

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 2, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 84 additions & 0 deletions c/tests/test_trees.c
Original file line number Diff line number Diff line change
Expand Up @@ -4639,6 +4639,89 @@ test_single_tree_map_mutations_internal_samples(void)
tsk_tree_free(&t);
}

/*
 * Query the tracked-sample count stored for `node` in `tree` and require
 * that the query succeeds (returns 0) and yields exactly `expected`.
 */
static void
check_num_tracked_samples(tsk_tree_t *tree, tsk_id_t node, tsk_size_t expected)
{
    tsk_size_t count;
    int status = tsk_tree_get_num_tracked_samples(tree, node, &count);

    CU_ASSERT_EQUAL_FATAL(status, 0);
    CU_ASSERT_EQUAL_FATAL(count, expected);
}

/*
 * Exercise tsk_tree_get_num_tracked_samples on the single-tree example while
 * nodes 0 and 1 are tracked: before the tree is positioned, after
 * tsk_tree_first, after each of two tsk_tree_next calls, and finally after
 * the tracked set has been replaced by an empty one. The virtual root is
 * queried at every stage.
 */
static void
test_single_tree_tracked_samples(void)
{
    tsk_treeseq_t tree_seq;
    tsk_tree_t t;
    tsk_id_t tracked[] = { 0, 1 };
    int status;

    tsk_treeseq_from_text(&tree_seq, 1, single_tree_ex_nodes, single_tree_ex_edges,
        NULL, single_tree_ex_sites, single_tree_ex_mutations, NULL, NULL, 0);

    status = tsk_tree_init(&t, &tree_seq, 0);
    CU_ASSERT_EQUAL_FATAL(status, 0);

    /* Tracked samples are set before tsk_tree_first has been called: each
     * tracked node counts itself, node 4 counts nothing, and the virtual
     * root reports both tracked samples. */
    status = tsk_tree_set_tracked_samples(&t, 2, tracked);
    CU_ASSERT_EQUAL_FATAL(status, 0);
    check_num_tracked_samples(&t, 0, 1);
    check_num_tracked_samples(&t, 4, 0);
    check_num_tracked_samples(&t, t.virtual_root, 2);

    /* Once the first tree is loaded the counts appear on nodes 4 and 6 as
     * well, while node 5 stays at zero. */
    status = tsk_tree_first(&t);
    CU_ASSERT_EQUAL_FATAL(status, 1);
    check_num_tracked_samples(&t, 0, 1);
    check_num_tracked_samples(&t, 4, 2);
    check_num_tracked_samples(&t, 5, 0);
    check_num_tracked_samples(&t, 6, 2);
    check_num_tracked_samples(&t, t.virtual_root, 2);

    /* tsk_tree_next returning 0 is expected here; the counts then match
     * those observed before tsk_tree_first was called. */
    status = tsk_tree_next(&t);
    CU_ASSERT_EQUAL_FATAL(status, 0);
    check_num_tracked_samples(&t, 0, 1);
    check_num_tracked_samples(&t, 4, 0);
    check_num_tracked_samples(&t, t.virtual_root, 2);

    /* A further tsk_tree_next is expected to return 1, after which node 4
     * carries both tracked samples again. */
    status = tsk_tree_next(&t);
    CU_ASSERT_EQUAL_FATAL(status, 1);
    check_num_tracked_samples(&t, 0, 1);
    check_num_tracked_samples(&t, 4, 2);
    check_num_tracked_samples(&t, t.virtual_root, 2);

    /* Installing an empty tracked set must zero the counts, including the
     * one held for the virtual root. */
    status = tsk_tree_set_tracked_samples(&t, 0, NULL);
    CU_ASSERT_EQUAL_FATAL(status, 0);
    check_num_tracked_samples(&t, 0, 0);
    check_num_tracked_samples(&t, t.virtual_root, 0);

    tsk_treeseq_free(&tree_seq);
    tsk_tree_free(&t);
}

/*=======================================================
* Multi tree tests.
*======================================================*/
Expand Down Expand Up @@ -6746,6 +6829,7 @@ main(int argc, char **argv)
{ "test_single_tree_map_mutations", test_single_tree_map_mutations },
{ "test_single_tree_map_mutations_internal_samples",
test_single_tree_map_mutations_internal_samples },
{ "test_single_tree_tracked_samples", test_single_tree_tracked_samples },

/* Multi tree tests */
{ "test_simple_multi_tree", test_simple_multi_tree },
Expand Down
7 changes: 5 additions & 2 deletions c/tskit/trees.c
Original file line number Diff line number Diff line change
Expand Up @@ -3314,7 +3314,7 @@ tsk_tree_reset_tracked_samples(tsk_tree_t *self)
goto out;
}
tsk_memset(self->num_tracked_samples, 0,
self->num_nodes * sizeof(*self->num_tracked_samples));
(self->num_nodes + 1) * sizeof(*self->num_tracked_samples));
out:
return ret;
}
Expand Down Expand Up @@ -3367,6 +3367,7 @@ tsk_tree_set_tracked_samples_from_sample_list(
tsk_id_t u, stop, index;
const tsk_id_t *next = other->next_sample;
const tsk_id_t *samples = other->tree_sequence->samples;
tsk_size_t num_tracked_samples = 0;

if (!tsk_tree_has_sample_lists(other)) {
ret = TSK_ERR_UNSUPPORTED_OPERATION;
Expand All @@ -3385,6 +3386,7 @@ tsk_tree_set_tracked_samples_from_sample_list(
stop = other->right_sample[node];
while (true) {
u = samples[index];
num_tracked_samples++;
tsk_bug_assert(self->num_tracked_samples[u] == 0);
/* Propagate this upwards */
while (u != TSK_NULL) {
Expand All @@ -3397,6 +3399,7 @@ tsk_tree_set_tracked_samples_from_sample_list(
index = next[index];
}
}
self->num_tracked_samples[self->virtual_root] = num_tracked_samples;
out:
return ret;
}
Expand Down Expand Up @@ -4293,7 +4296,7 @@ tsk_tree_clear(tsk_tree_t *self)
self->num_tracked_samples[j] = 0;
}
}
self->num_tracked_samples[self->virtual_root] = 0;
/* The total tracked_samples gets set in set_tracked_samples */
self->num_samples[self->virtual_root] = num_samples;
}
if (sample_lists) {
Expand Down
33 changes: 13 additions & 20 deletions python/tests/test_highlevel.py
Original file line number Diff line number Diff line change
Expand Up @@ -1464,39 +1464,32 @@ def test_compute_mutation_time(self):
# Check we have valid times
tables.tree_sequence()

def verify_tracked_samples(self, ts):
@pytest.mark.parametrize("ts", get_example_tree_sequences())
def test_tracked_samples(self, ts):
# Should be empty list by default.
for tree in ts.trees():
assert tree.get_num_tracked_samples() == 0
assert tree.num_tracked_samples() == 0
for u in tree.nodes():
assert tree.get_num_tracked_samples(u) == 0
assert tree.num_tracked_samples(u) == 0
samples = list(ts.samples())
tracked_samples = samples[:2]
for tree in ts.trees(tracked_samples=tracked_samples):
if len(tree.parent_dict) == 0:
# This is a crude way of checking if we have multiple roots.
# We'll need to fix this code up properly when we support multiple
# roots and remove this check
break
nu = [0 for j in range(ts.get_num_nodes())]
assert tree.get_num_tracked_samples() == len(tracked_samples)
nu = [0 for j in range(ts.num_nodes)]
assert tree.num_tracked_samples() == len(tracked_samples)
for j in tracked_samples:
u = j
while u != tskit.NULL:
nu[u] += 1
u = tree.get_parent(u)
u = tree.parent(u)
for u, count in enumerate(nu):
assert tree.get_num_tracked_samples(u) == count

def test_tracked_samples(self):
for ts in get_example_tree_sequences():
self.verify_tracked_samples(ts)
assert tree.num_tracked_samples(u) == count
assert tree.num_tracked_samples(tree.virtual_root) == len(tracked_samples)

def test_tracked_samples_is_first_arg(self):
for ts in get_example_tree_sequences():
samples = list(ts.samples())[:2]
for a, b in zip(ts.trees(samples), ts.trees(tracked_samples=samples)):
assert a.get_num_tracked_samples() == b.get_num_tracked_samples()
ts = tskit.Tree.generate_balanced(6).tree_sequence
samples = [0, 1, 2]
tree = next(ts.trees(samples))
assert tree.num_tracked_samples() == 3

def test_deprecated_sample_aliases(self):
for ts in get_example_tree_sequences():
Expand Down
10 changes: 2 additions & 8 deletions python/tskit/trees.py
Original file line number Diff line number Diff line change
Expand Up @@ -2106,14 +2106,8 @@ def num_tracked_samples(self, u=None):
the subtree rooted at u.
:rtype: int
"""
# This should work, there's a but somethings wrong somewhere
# https://github.com/tskit-dev/tskit/issues/1724
# u = self.virtual_root if u is None else u
# return self._ll_tree.get_num_tracked_samples(u)
roots = [u]
if u is None:
roots = self.roots
return sum(self._ll_tree.get_num_tracked_samples(root) for root in roots)
u = self.virtual_root if u is None else u
return self._ll_tree.get_num_tracked_samples(u)

# TODO document these traversal arrays
# https://github.com/tskit-dev/tskit/issues/1788
Expand Down