Skip to content

Commit 365abef

Browse files
Work on traversal orders in Python
1 parent 0f94cc8 commit 365abef

File tree

2 files changed

+130
-143
lines changed

2 files changed

+130
-143
lines changed

python/tests/__init__.py

Lines changed: 0 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -83,80 +83,6 @@ def children(self, u):
8383
v = self.right_sib[v]
8484
return ret
8585

86-
def _preorder_nodes(self, u, node_ist):
87-
node_ist.append(u)
88-
for c in self.children(u):
89-
self._preorder_nodes(c, node_ist)
90-
91-
def _postorder_nodes(self, u, node_ist):
92-
for c in self.children(u):
93-
self._postorder_nodes(c, node_ist)
94-
node_ist.append(u)
95-
96-
def _inorder_nodes(self, u, node_ist):
97-
children = self.children(u)
98-
if len(children) > 0:
99-
mid = len(children) // 2
100-
for v in children[:mid]:
101-
self._inorder_nodes(v, node_ist)
102-
node_ist.append(u)
103-
for v in children[mid:]:
104-
self._inorder_nodes(v, node_ist)
105-
else:
106-
node_ist.append(u)
107-
108-
def _levelorder_nodes(self, u, node_ist, level):
109-
node_ist[level].append(u) if level < len(node_ist) else node_ist.append([u])
110-
for c in self.children(u):
111-
self._levelorder_nodes(c, node_ist, level + 1)
112-
113-
def _minlex_postorder_nodes(self, u, node_ist):
114-
node_ist.extend(self._minlex_postorder_nodes_helper(u)[1])
115-
116-
def _minlex_postorder_nodes_helper(self, u):
117-
"""
118-
For a given input ID u, this function returns a tuple whose first value
119-
is the minimum leaf node ID under node u, and whose second value is
120-
a list containing the minlex postorder for the subtree rooted at node u.
121-
The first value is needed for sorting, and the second value is what
122-
finally gets returned.
123-
"""
124-
children = self.children(u)
125-
if len(children) > 0:
126-
children_return = [self._minlex_postorder_nodes_helper(c) for c in children]
127-
# sorts by first value, which is the minimum leaf node ID
128-
children_return.sort()
129-
minlex_postorder = []
130-
for _, child_minlex_postorder in children_return:
131-
minlex_postorder.extend(child_minlex_postorder)
132-
minlex_postorder.extend([u])
133-
return (children_return[0][0], minlex_postorder)
134-
else:
135-
return (u, [u])
136-
137-
def nodes(self, root=None, order="preorder"):
138-
roots = [root]
139-
if root is None:
140-
roots = self.roots
141-
for u in roots:
142-
node_list = []
143-
if order == "preorder":
144-
self._preorder_nodes(u, node_list)
145-
elif order == "inorder":
146-
self._inorder_nodes(u, node_list)
147-
elif order == "postorder":
148-
self._postorder_nodes(u, node_list)
149-
elif order == "levelorder" or order == "breadthfirst":
150-
# Returns nodes in their respective levels
151-
# Nested list comprehension flattens node_list in order
152-
self._levelorder_nodes(u, node_list, 0)
153-
node_list = iter([i for level in node_list for i in level])
154-
elif order == "minlex_postorder":
155-
self._minlex_postorder_nodes(u, node_list)
156-
else:
157-
raise ValueError("order not supported")
158-
yield from node_list
159-
16086
def get_interval(self):
16187
return self.left, self.right
16288

python/tests/test_highlevel.py

Lines changed: 130 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,103 @@
5858
from tskit import UNKNOWN_TIME
5959

6060

61+
def traversal_preorder(tree, root=None):
62+
roots = tree.roots if root is None else [root]
63+
for node in roots:
64+
yield node
65+
for child in tree.children(node):
66+
yield from traversal_preorder(tree, child)
67+
68+
69+
def traversal_postorder(tree, root=None):
70+
roots = tree.roots if root is None else [root]
71+
for node in roots:
72+
for child in tree.children(node):
73+
yield from traversal_postorder(tree, child)
74+
yield node
75+
76+
77+
def traversal_inorder(tree, root=None):
78+
roots = tree.roots if root is None else [root]
79+
for node in roots:
80+
children = list(tree.children(node))
81+
half = len(children) // 2
82+
for child in children[:half]:
83+
yield from traversal_inorder(tree, child)
84+
yield node
85+
for child in children[half:]:
86+
yield from traversal_inorder(tree, child)
87+
88+
89+
def traversal_levelorder(tree, root=None):
90+
roots = tree.roots if root is None else [root]
91+
for node in roots:
92+
# Sorted is a stable ordering so nodes are in tree order within
93+
# the level
94+
yield from sorted(list(tree.nodes(node)), key=lambda u: tree.depth(u))
95+
96+
97+
def _traversal_minlex_postorder(tree, u):
98+
"""
99+
For a given input ID u, this function returns a tuple whose first value
100+
is the minimum leaf node ID under node u, and whose second value is
101+
a list containing the minlex postorder for the subtree rooted at node u.
102+
The first value is needed for sorting, and the second value is what
103+
finally gets returned.
104+
"""
105+
children = tree.children(u)
106+
if len(children) > 0:
107+
children_return = [_traversal_minlex_postorder(tree, c) for c in children]
108+
# sorts by first value, which is the minimum leaf node ID
109+
children_return.sort(key=lambda x: x[0])
110+
minlex_postorder = []
111+
for _, child_minlex_postorder in children_return:
112+
minlex_postorder.extend(child_minlex_postorder)
113+
minlex_postorder.extend([u])
114+
return (children_return[0][0], minlex_postorder)
115+
else:
116+
return (u, [u])
117+
118+
119+
def traversal_minlex_postorder(tree, root=None):
120+
roots = tree.roots if root is None else [root]
121+
root_lists = [_traversal_minlex_postorder(tree, node) for node in roots]
122+
for _, node_list in sorted(root_lists, key=lambda x: x[0]):
123+
yield from node_list
124+
125+
126+
def traversal_timeasc(tree, root=None):
127+
# This is wrong, it should be
128+
# nodes = sorted(tree.nodes(root), key=lambda u: (tree.time(u), u))
129+
# https://github.com/tskit-dev/tskit/issues/1776
130+
roots = tree.roots if root is None else [root]
131+
for node in roots:
132+
nodes = sorted(tree.nodes(node), key=lambda u: (tree.time(u), u))
133+
yield from nodes
134+
135+
136+
def traversal_timedesc(tree, root=None):
137+
# This is wrong, it should be
138+
# nodes = sorted(tree.nodes(root), key=lambda u: (-tree.time(u), u))
139+
# https://github.com/tskit-dev/tskit/issues/1776
140+
roots = tree.roots if root is None else [root]
141+
for node in roots:
142+
nodes = sorted(tree.nodes(node), key=lambda u: (tree.time(u), u), reverse=True)
143+
yield from nodes
144+
145+
146+
traversal_map = {
147+
"preorder": traversal_preorder,
148+
"postorder": traversal_postorder,
149+
"inorder": traversal_inorder,
150+
"levelorder": traversal_levelorder,
151+
"breadthfirst": traversal_levelorder,
152+
"minlex_postorder": traversal_minlex_postorder,
153+
"timeasc": traversal_timeasc,
154+
"timedesc": traversal_timedesc,
155+
}
156+
157+
61158
def insert_uniform_mutations(tables, num_mutations, nodes):
62159
"""
63160
Returns n evenly mutations over the specified list of nodes.
@@ -363,6 +460,39 @@ def get_samples(ts, time=None, population=None):
363460
return np.array(samples)
364461

365462

463+
class TestTreeTraversals:
464+
@pytest.mark.parametrize("ts", get_example_tree_sequences())
465+
@pytest.mark.parametrize(
466+
"order",
467+
[
468+
"preorder",
469+
# "postorder",
470+
# "inorder",
471+
# "levelorder",
472+
# "breadthfirst",
473+
# "minlex_postorder",
474+
],
475+
)
476+
def test_traversals_virtual_root(self, ts, order):
477+
tree = ts.first()
478+
node_list2 = list(traversal_map[order](tree, tree.virtual_root))
479+
node_list1 = list(tree.nodes(tree.virtual_root, order=order))
480+
assert tree.virtual_root in node_list1
481+
assert node_list1 == node_list2
482+
483+
@pytest.mark.parametrize("ts", get_example_tree_sequences())
484+
@pytest.mark.parametrize("order", list(traversal_map.keys()))
485+
def test_traversals(self, ts, order):
486+
tree = next(ts.trees())
487+
traverser = traversal_map[order]
488+
node_list1 = list(tree.nodes(order=order))
489+
node_list2 = list(traverser(tree))
490+
assert node_list1 == node_list2
491+
492+
# TODO add tests for traversing from specific nodes and a couple
493+
# of hand crafted examples just for sanity checking.
494+
495+
366496
class TestMRCACalculator:
367497
"""
368498
Class to test the Schieber-Vishkin algorithm.
@@ -2649,75 +2779,6 @@ def verify_nx_nearest_neighbor_search(self):
26492779
nearest_neighbor_of = [min(dist_dod[u], key=dist_dod[u].get) for u in range(3)]
26502780
assert [2, 2, 1] == [nearest_neighbor_of[u] for u in range(3)]
26512781

2652-
def test_traversals(self):
2653-
for ts in get_example_tree_sequences():
2654-
tree = next(ts.trees())
2655-
self.verify_traversals(tree)
2656-
2657-
# Verify time-ordered traversals separately, because the PythonTree
2658-
# class does not contain time information at the moment
2659-
for root in tree.roots:
2660-
time_ordered = tree.nodes(root, order="timeasc")
2661-
t = tree.time(next(time_ordered))
2662-
for u in time_ordered:
2663-
next_t = tree.time(u)
2664-
assert next_t >= t
2665-
t = next_t
2666-
time_ordered = tree.nodes(root, order="timedesc")
2667-
t = tree.time(next(time_ordered))
2668-
for u in time_ordered:
2669-
next_t = tree.time(u)
2670-
assert next_t <= t
2671-
t = next_t
2672-
2673-
def verify_traversals(self, tree):
2674-
t1 = tree
2675-
t2 = tests.PythonTree.from_tree(t1)
2676-
assert list(t1.nodes()) == list(t2.nodes())
2677-
orders = [
2678-
"inorder",
2679-
"postorder",
2680-
"levelorder",
2681-
"breadthfirst",
2682-
"minlex_postorder",
2683-
]
2684-
if tree.num_roots == 1:
2685-
with pytest.raises(ValueError):
2686-
list(t1.nodes(order="bad order"))
2687-
assert list(t1.nodes()) == list(t1.nodes(t1.get_root()))
2688-
assert list(t1.nodes()) == list(t1.nodes(t1.get_root(), "preorder"))
2689-
for u in t1.nodes():
2690-
assert list(t1.nodes(u)) == list(t2.nodes(u))
2691-
for test_order in orders:
2692-
assert sorted(list(t1.nodes())) == sorted(
2693-
list(t1.nodes(order=test_order))
2694-
)
2695-
assert list(t1.nodes(order=test_order)) == list(
2696-
t1.nodes(t1.get_root(), order=test_order)
2697-
)
2698-
assert list(t1.nodes(order=test_order)) == list(
2699-
t1.nodes(t1.get_root(), test_order)
2700-
)
2701-
assert list(t1.nodes(order=test_order)) == list(
2702-
t2.nodes(order=test_order)
2703-
)
2704-
for u in t1.nodes():
2705-
assert list(t1.nodes(u, test_order)) == list(
2706-
t2.nodes(u, test_order)
2707-
)
2708-
else:
2709-
for test_order in orders:
2710-
all_nodes = []
2711-
for root in t1.roots:
2712-
assert list(t1.nodes(root, order=test_order)) == list(
2713-
t2.nodes(root, order=test_order)
2714-
)
2715-
all_nodes.extend(t1.nodes(root, order=test_order))
2716-
# minlex_postorder reorders the roots, so this last test is
2717-
# not appropriate
2718-
if test_order != "minlex_postorder":
2719-
assert all_nodes == list(t1.nodes(order=test_order))
2720-
27212782
def test_total_branch_length(self):
27222783
# Note: this definition works when we have no non-sample branches.
27232784
t1 = self.get_tree()

0 commit comments

Comments
 (0)