|
58 | 58 | from tskit import UNKNOWN_TIME
|
59 | 59 |
|
60 | 60 |
|
| 61 | +def traversal_preorder(tree, root=None): |
| 62 | + roots = tree.roots if root is None else [root] |
| 63 | + for node in roots: |
| 64 | + yield node |
| 65 | + for child in tree.children(node): |
| 66 | + yield from traversal_preorder(tree, child) |
| 67 | + |
| 68 | + |
| 69 | +def traversal_postorder(tree, root=None): |
| 70 | + roots = tree.roots if root is None else [root] |
| 71 | + for node in roots: |
| 72 | + for child in tree.children(node): |
| 73 | + yield from traversal_postorder(tree, child) |
| 74 | + yield node |
| 75 | + |
| 76 | + |
| 77 | +def traversal_inorder(tree, root=None): |
| 78 | + roots = tree.roots if root is None else [root] |
| 79 | + for node in roots: |
| 80 | + children = list(tree.children(node)) |
| 81 | + half = len(children) // 2 |
| 82 | + for child in children[:half]: |
| 83 | + yield from traversal_inorder(tree, child) |
| 84 | + yield node |
| 85 | + for child in children[half:]: |
| 86 | + yield from traversal_inorder(tree, child) |
| 87 | + |
| 88 | + |
| 89 | +def traversal_levelorder(tree, root=None): |
| 90 | + roots = tree.roots if root is None else [root] |
| 91 | + for node in roots: |
| 92 | + # Sorted is a stable ordering so nodes are in tree order within |
| 93 | + # the level |
| 94 | + yield from sorted(list(tree.nodes(node)), key=lambda u: tree.depth(u)) |
| 95 | + |
| 96 | + |
| 97 | +def _traversal_minlex_postorder(tree, u): |
| 98 | + """ |
| 99 | + For a given input ID u, this function returns a tuple whose first value |
| 100 | + is the minimum leaf node ID under node u, and whose second value is |
| 101 | + a list containing the minlex postorder for the subtree rooted at node u. |
| 102 | + The first value is needed for sorting, and the second value is what |
| 103 | + finally gets returned. |
| 104 | + """ |
| 105 | + children = tree.children(u) |
| 106 | + if len(children) > 0: |
| 107 | + children_return = [_traversal_minlex_postorder(tree, c) for c in children] |
| 108 | + # sorts by first value, which is the minimum leaf node ID |
| 109 | + children_return.sort(key=lambda x: x[0]) |
| 110 | + minlex_postorder = [] |
| 111 | + for _, child_minlex_postorder in children_return: |
| 112 | + minlex_postorder.extend(child_minlex_postorder) |
| 113 | + minlex_postorder.extend([u]) |
| 114 | + return (children_return[0][0], minlex_postorder) |
| 115 | + else: |
| 116 | + return (u, [u]) |
| 117 | + |
| 118 | + |
| 119 | +def traversal_minlex_postorder(tree, root=None): |
| 120 | + roots = tree.roots if root is None else [root] |
| 121 | + root_lists = [_traversal_minlex_postorder(tree, node) for node in roots] |
| 122 | + for _, node_list in sorted(root_lists, key=lambda x: x[0]): |
| 123 | + yield from node_list |
| 124 | + |
| 125 | + |
| 126 | +def traversal_timeasc(tree, root=None): |
| 127 | + # This is wrong, it should be |
| 128 | + # nodes = sorted(tree.nodes(root), key=lambda u: (tree.time(u), u)) |
| 129 | + # https://github.com/tskit-dev/tskit/issues/1776 |
| 130 | + roots = tree.roots if root is None else [root] |
| 131 | + for node in roots: |
| 132 | + nodes = sorted(tree.nodes(node), key=lambda u: (tree.time(u), u)) |
| 133 | + yield from nodes |
| 134 | + |
| 135 | + |
| 136 | +def traversal_timedesc(tree, root=None): |
| 137 | + # This is wrong, it should be |
| 138 | + # nodes = sorted(tree.nodes(root), key=lambda u: (-tree.time(u), u)) |
| 139 | + # https://github.com/tskit-dev/tskit/issues/1776 |
| 140 | + roots = tree.roots if root is None else [root] |
| 141 | + for node in roots: |
| 142 | + nodes = sorted(tree.nodes(node), key=lambda u: (tree.time(u), u), reverse=True) |
| 143 | + yield from nodes |
| 144 | + |
| 145 | + |
| 146 | +traversal_map = { |
| 147 | + "preorder": traversal_preorder, |
| 148 | + "postorder": traversal_postorder, |
| 149 | + "inorder": traversal_inorder, |
| 150 | + "levelorder": traversal_levelorder, |
| 151 | + "breadthfirst": traversal_levelorder, |
| 152 | + "minlex_postorder": traversal_minlex_postorder, |
| 153 | + "timeasc": traversal_timeasc, |
| 154 | + "timedesc": traversal_timedesc, |
| 155 | +} |
| 156 | + |
| 157 | + |
61 | 158 | def insert_uniform_mutations(tables, num_mutations, nodes):
|
62 | 159 | """
|
63 | 160 | Returns n evenly mutations over the specified list of nodes.
|
@@ -363,6 +460,39 @@ def get_samples(ts, time=None, population=None):
|
363 | 460 | return np.array(samples)
|
364 | 461 |
|
365 | 462 |
|
| 463 | +class TestTreeTraversals: |
| 464 | + @pytest.mark.parametrize("ts", get_example_tree_sequences()) |
| 465 | + @pytest.mark.parametrize( |
| 466 | + "order", |
| 467 | + [ |
| 468 | + "preorder", |
| 469 | + # "postorder", |
| 470 | + # "inorder", |
| 471 | + # "levelorder", |
| 472 | + # "breadthfirst", |
| 473 | + # "minlex_postorder", |
| 474 | + ], |
| 475 | + ) |
| 476 | + def test_traversals_virtual_root(self, ts, order): |
| 477 | + tree = ts.first() |
| 478 | + node_list2 = list(traversal_map[order](tree, tree.virtual_root)) |
| 479 | + node_list1 = list(tree.nodes(tree.virtual_root, order=order)) |
| 480 | + assert tree.virtual_root in node_list1 |
| 481 | + assert node_list1 == node_list2 |
| 482 | + |
| 483 | + @pytest.mark.parametrize("ts", get_example_tree_sequences()) |
| 484 | + @pytest.mark.parametrize("order", list(traversal_map.keys())) |
| 485 | + def test_traversals(self, ts, order): |
| 486 | + tree = next(ts.trees()) |
| 487 | + traverser = traversal_map[order] |
| 488 | + node_list1 = list(tree.nodes(order=order)) |
| 489 | + node_list2 = list(traverser(tree)) |
| 490 | + assert node_list1 == node_list2 |
| 491 | + |
| 492 | + # TODO add tests for traversing from specific nodes and a couple |
| 493 | + # of hand crafted examples just for sanity checking. |
| 494 | + |
| 495 | + |
366 | 496 | class TestMRCACalculator:
|
367 | 497 | """
|
368 | 498 | Class to test the Schieber-Vishkin algorithm.
|
@@ -2649,75 +2779,6 @@ def verify_nx_nearest_neighbor_search(self):
|
2649 | 2779 | nearest_neighbor_of = [min(dist_dod[u], key=dist_dod[u].get) for u in range(3)]
|
2650 | 2780 | assert [2, 2, 1] == [nearest_neighbor_of[u] for u in range(3)]
|
2651 | 2781 |
|
2652 |
| - def test_traversals(self): |
2653 |
| - for ts in get_example_tree_sequences(): |
2654 |
| - tree = next(ts.trees()) |
2655 |
| - self.verify_traversals(tree) |
2656 |
| - |
2657 |
| - # Verify time-ordered traversals separately, because the PythonTree |
2658 |
| - # class does not contain time information at the moment |
2659 |
| - for root in tree.roots: |
2660 |
| - time_ordered = tree.nodes(root, order="timeasc") |
2661 |
| - t = tree.time(next(time_ordered)) |
2662 |
| - for u in time_ordered: |
2663 |
| - next_t = tree.time(u) |
2664 |
| - assert next_t >= t |
2665 |
| - t = next_t |
2666 |
| - time_ordered = tree.nodes(root, order="timedesc") |
2667 |
| - t = tree.time(next(time_ordered)) |
2668 |
| - for u in time_ordered: |
2669 |
| - next_t = tree.time(u) |
2670 |
| - assert next_t <= t |
2671 |
| - t = next_t |
2672 |
| - |
2673 |
| - def verify_traversals(self, tree): |
2674 |
| - t1 = tree |
2675 |
| - t2 = tests.PythonTree.from_tree(t1) |
2676 |
| - assert list(t1.nodes()) == list(t2.nodes()) |
2677 |
| - orders = [ |
2678 |
| - "inorder", |
2679 |
| - "postorder", |
2680 |
| - "levelorder", |
2681 |
| - "breadthfirst", |
2682 |
| - "minlex_postorder", |
2683 |
| - ] |
2684 |
| - if tree.num_roots == 1: |
2685 |
| - with pytest.raises(ValueError): |
2686 |
| - list(t1.nodes(order="bad order")) |
2687 |
| - assert list(t1.nodes()) == list(t1.nodes(t1.get_root())) |
2688 |
| - assert list(t1.nodes()) == list(t1.nodes(t1.get_root(), "preorder")) |
2689 |
| - for u in t1.nodes(): |
2690 |
| - assert list(t1.nodes(u)) == list(t2.nodes(u)) |
2691 |
| - for test_order in orders: |
2692 |
| - assert sorted(list(t1.nodes())) == sorted( |
2693 |
| - list(t1.nodes(order=test_order)) |
2694 |
| - ) |
2695 |
| - assert list(t1.nodes(order=test_order)) == list( |
2696 |
| - t1.nodes(t1.get_root(), order=test_order) |
2697 |
| - ) |
2698 |
| - assert list(t1.nodes(order=test_order)) == list( |
2699 |
| - t1.nodes(t1.get_root(), test_order) |
2700 |
| - ) |
2701 |
| - assert list(t1.nodes(order=test_order)) == list( |
2702 |
| - t2.nodes(order=test_order) |
2703 |
| - ) |
2704 |
| - for u in t1.nodes(): |
2705 |
| - assert list(t1.nodes(u, test_order)) == list( |
2706 |
| - t2.nodes(u, test_order) |
2707 |
| - ) |
2708 |
| - else: |
2709 |
| - for test_order in orders: |
2710 |
| - all_nodes = [] |
2711 |
| - for root in t1.roots: |
2712 |
| - assert list(t1.nodes(root, order=test_order)) == list( |
2713 |
| - t2.nodes(root, order=test_order) |
2714 |
| - ) |
2715 |
| - all_nodes.extend(t1.nodes(root, order=test_order)) |
2716 |
| - # minlex_postorder reorders the roots, so this last test is |
2717 |
| - # not appropriate |
2718 |
| - if test_order != "minlex_postorder": |
2719 |
| - assert all_nodes == list(t1.nodes(order=test_order)) |
2720 |
| - |
2721 | 2782 | def test_total_branch_length(self):
|
2722 | 2783 | # Note: this definition works when we have no non-sample branches.
|
2723 | 2784 | t1 = self.get_tree()
|
|
0 commit comments