Skip to content

Commit 857f1c6

Browse files
authored
Merge pull request #143 from kayjan/enhance-pruning
Enhance pruning
2 parents 40178b2 + 66419a5 commit 857f1c6

File tree

8 files changed

+204
-19
lines changed

8 files changed

+204
-19
lines changed

bigtree/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
tree_to_dot,
2929
tree_to_mermaid,
3030
tree_to_nested_dict,
31+
tree_to_newick,
3132
tree_to_pillow,
3233
yield_tree,
3334
)

bigtree/tree/export.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
"tree_to_dot",
4040
"tree_to_pillow",
4141
"tree_to_mermaid",
42+
"tree_to_newick",
4243
]
4344

4445
T = TypeVar("T", bound=Node)
@@ -1156,3 +1157,36 @@ def get_attr(
11561157
flows="\n".join(flows),
11571158
styles="\n".join(styles),
11581159
)
1160+
1161+
1162+
def tree_to_newick(tree: T) -> str:
1163+
"""Export tree to Newick notation.
1164+
1165+
>>> from bigtree import Node, tree_to_newick
1166+
>>> root = Node("a", age=90)
1167+
>>> b = Node("b", age=65, parent=root)
1168+
>>> c = Node("c", age=60, parent=root)
1169+
>>> d = Node("d", age=40, parent=b)
1170+
>>> e = Node("e", age=35, parent=b)
1171+
>>> root.show()
1172+
a
1173+
├── b
1174+
│ ├── d
1175+
│ └── e
1176+
└── c
1177+
1178+
>>> tree_to_newick(root)
1179+
'((d,e)b,c)a'
1180+
1181+
Args:
1182+
tree (Node): tree to be exported
1183+
1184+
Returns:
1185+
(str)
1186+
"""
1187+
if not tree:
1188+
return ""
1189+
if tree.is_leaf:
1190+
return tree.node_name
1191+
children_newick = ",".join(tree_to_newick(child) for child in tree.children)
1192+
return f"({children_newick}){tree.node_name}"

bigtree/tree/helper.py

Lines changed: 58 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from collections import deque
2-
from typing import Any, Deque, Dict, List, Type, TypeVar, Union
2+
from typing import Any, Deque, Dict, List, Set, Type, TypeVar, Union
33

44
from bigtree.node.basenode import BaseNode
55
from bigtree.node.binarynode import BinaryNode
@@ -56,18 +56,21 @@ def recursive_add_child(
5656

5757
def prune_tree(
5858
tree: Union[BinaryNodeT, NodeT],
59-
prune_path: str = "",
59+
prune_path: Union[List[str], str] = "",
60+
exact: bool = False,
6061
sep: str = "/",
6162
max_depth: int = 0,
6263
) -> Union[BinaryNodeT, NodeT]:
6364
"""Prune tree by path or depth, returns the root of a *copy* of the original tree.
6465
6566
For pruning by `prune_path`,
66-
All siblings along the prune path will be removed.
67-
Prune path name should be unique, can be full path, partial path (trailing part of path), or node name.
67+
- All siblings along the prune path will be removed.
68+
- If ``exact=True``, all descendants of prune path will be removed.
69+
- Prune path can be string (only one path) or a list of strings (multiple paths).
70+
- Prune path name should be unique, can be full path, partial path (trailing part of path), or node name.
6871
6972
For pruning by `max_depth`,
70-
All nodes that are beyond `max_depth` will be removed.
73+
- All nodes that are beyond `max_depth` will be removed.
7174
7275
Path should contain ``Node`` name, separated by `sep`.
7376
- For example: Path string "a/b" refers to Node("b") with parent Node("a").
@@ -85,13 +88,33 @@ def prune_tree(
8588
│ └── d
8689
└── e
8790
91+
Prune (default is keep descendants)
92+
8893
>>> root_pruned = prune_tree(root, "a/b")
8994
>>> root_pruned.show()
9095
a
9196
└── b
9297
├── c
9398
└── d
9499
100+
Prune exact path
101+
102+
>>> root_pruned = prune_tree(root, "a/b", exact=True)
103+
>>> root_pruned.show()
104+
a
105+
└── b
106+
107+
Prune multiple paths
108+
109+
>>> root_pruned = prune_tree(root, ["a/b/d", "a/e"])
110+
>>> root_pruned.show()
111+
a
112+
├── b
113+
│ └── d
114+
└── e
115+
116+
Prune by depth
117+
95118
>>> root_pruned = prune_tree(root, max_depth=2)
96119
>>> root_pruned.show()
97120
a
@@ -100,31 +123,47 @@ def prune_tree(
100123
101124
Args:
102125
tree (Union[BinaryNode, Node]): existing tree
103-
prune_path (str): prune path, all siblings along the prune path will be removed
126+
prune_path (List[str] | str): prune path(s), all siblings along the prune path(s) will be removed
127+
exact (bool): prune path(s) to be exactly the path, defaults to False (descendants of the path are retained)
104128
sep (str): path separator of `prune_path`
105129
max_depth (int): maximum depth of pruned tree, based on `depth` attribute, defaults to None
106130
107131
Returns:
108132
(Union[BinaryNode, Node])
109133
"""
110-
if not prune_path and not max_depth:
134+
if isinstance(prune_path, str):
135+
prune_path = [prune_path] if prune_path else []
136+
137+
if not len(prune_path) and not max_depth:
111138
raise ValueError("Please specify either `prune_path` or `max_depth` or both.")
112139

113140
tree_copy = tree.copy()
114141

115142
# Prune by path (prune bottom-up)
116-
if prune_path:
117-
prune_path = prune_path.replace(sep, tree.sep)
118-
child = find_path(tree_copy, prune_path)
119-
if not child:
120-
raise NotFoundError(
121-
f"Cannot find any node matching path_name ending with {prune_path}"
122-
)
123-
while child.parent:
124-
for other_children in child.parent.children:
125-
if other_children != child:
126-
other_children.parent = None
127-
child = child.parent
143+
if len(prune_path):
144+
ancestors_to_prune: Set[Union[BinaryNodeT, NodeT]] = set()
145+
nodes_to_prune: Set[Union[BinaryNodeT, NodeT]] = set()
146+
for path in prune_path:
147+
path = path.replace(sep, tree.sep)
148+
child = find_path(tree_copy, path)
149+
if not child:
150+
raise NotFoundError(
151+
f"Cannot find any node matching path_name ending with {path}"
152+
)
153+
nodes_to_prune.add(child)
154+
ancestors_to_prune.update(list(child.ancestors))
155+
156+
if exact:
157+
ancestors_to_prune.update(nodes_to_prune)
158+
159+
for node in ancestors_to_prune:
160+
for child in node.children:
161+
if (
162+
child
163+
and child not in ancestors_to_prune
164+
and child not in nodes_to_prune
165+
):
166+
child.parent = None
128167

129168
# Prune by depth (prune top-down)
130169
if max_depth:

docs/source/bigtree/tree/export.rst

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ Export Tree to list, dictionary, and pandas DataFrame.
1111
- Method
1212
* - Command Line / Others
1313
- `print_tree`, `yield_tree`
14+
* - String
15+
- `tree_to_newick`
1416
* - Dictionary
1517
- `tree_to_dict`, `tree_to_nested_dict`
1618
* - DataFrame
@@ -46,6 +48,12 @@ While exporting to another data type, methods can take in arguments to determine
4648
- No, but can specify subtree
4749
- No
4850
- Tree style
51+
* - `tree_to_newick`
52+
- No
53+
- No
54+
- No
55+
- No
56+
- N/A
4957
* - `tree_to_dict`
5058
- Yes with `attr_dict` or `all_attrs`
5159
- Yes with `max_depth`

tests/binarytree/test_export.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
tree_to_dict,
77
tree_to_dot,
88
tree_to_nested_dict,
9+
tree_to_newick,
910
)
1011
from tests.conftest import assert_print_statement
1112
from tests.test_constants import Constants
@@ -89,3 +90,11 @@ def test_tree_to_dot(binarytree_node):
8990
assert (
9091
expected_str in actual
9192
), f"Expected {expected_str} not in actual string"
93+
94+
95+
class TestTreeToNewick:
96+
@staticmethod
97+
def test_tree_to_newick(binarytree_node):
98+
newick_str = tree_to_newick(binarytree_node)
99+
expected_str = """(((,8)4,5)2,(6,7)3)1"""
100+
assert newick_str == expected_str

tests/binarytree/test_helper.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,20 @@ def test_prune_tree(binarytree_node):
3030
None,
3131
), "Check children (depth 5)"
3232

33+
@staticmethod
34+
def test_prune_tree_exact(binarytree_node):
35+
# Pruned tree is 1/2/4
36+
tree_prune = prune_tree(binarytree_node, "1/2/4", exact=True)
37+
38+
assert tree_prune.left.val == 2, "Check left child of tree_prune (depth 2)"
39+
assert not tree_prune.right, "Check right child of tree_prune (depth 2)"
40+
assert tree_prune.left.left.val == 4, "Check left child (depth 3)"
41+
assert not tree_prune.left.right, "Check right child (depth 3)"
42+
assert tree_prune.left.left.children == (
43+
None,
44+
None,
45+
), "Check children (depth 4)"
46+
3347
@staticmethod
3448
def test_prune_tree_second_child(binarytree_node):
3549
# Pruned tree is 1/3/6, 1/3/7
@@ -48,6 +62,24 @@ def test_prune_tree_second_child(binarytree_node):
4862
None,
4963
), "Check children (depth 4)"
5064

65+
@staticmethod
66+
def test_prune_tree_list(binarytree_node):
67+
# Pruned tree is 1/3/6, 1/3/7
68+
tree_prune = prune_tree(binarytree_node, ["1/3/6", "1/3/7"])
69+
70+
assert not tree_prune.left, "Check left child of tree_prune (depth 2)"
71+
assert tree_prune.right.val == 3, "Check right child of tree_prune (depth 2)"
72+
assert tree_prune.right.left.val == 6, "Check left child (depth 3)"
73+
assert tree_prune.right.right.val == 7, "Check right child (depth 3)"
74+
assert tree_prune.right.left.children == (
75+
None,
76+
None,
77+
), "Check children (depth 4)"
78+
assert tree_prune.right.right.children == (
79+
None,
80+
None,
81+
), "Check children (depth 4)"
82+
5183

5284
class TestTreeDiff:
5385
@staticmethod

tests/tree/test_export.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
tree_to_dot,
1010
tree_to_mermaid,
1111
tree_to_nested_dict,
12+
tree_to_newick,
1213
tree_to_pillow,
1314
)
1415
from tests.conftest import assert_print_statement
@@ -1376,3 +1377,11 @@ def get_node_attr(node):
13761377
expected_str = """```mermaid\n%%{ init: { \'flowchart\': { \'curve\': \'basis\' } } }%%\nflowchart TB\n0("a") --> 0-0("b")\n0-0 --> 0-0-0("d")\n0-0-0:::class0-0-0-0 --> 0-0-0-0("g")\n0-0 --> 0-0-1("e")\n0-0-1:::class0-0-1-0 --> 0-0-1-0("h")\n0("a") --> 0-1("c")\n0-1 --> 0-1-0("f")\nclassDef default stroke-width:1\nclassDef class0-0-0-0 fill:red,stroke:black,stroke-width:2\nclassDef class0-0-1-0 fill:red,stroke:black,stroke-width:2\n```"""
13771378

13781379
assert mermaid_md == expected_str
1380+
1381+
1382+
class TestTreeToNewick:
1383+
@staticmethod
1384+
def test_tree_to_newick(tree_node):
1385+
newick_str = tree_to_newick(tree_node)
1386+
expected_str = """((d,(g,h)e)b,(f)c)a"""
1387+
assert newick_str == expected_str

tests/tree/test_helper.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,29 @@ def test_prune_tree(tree_node):
6262
assert len(tree_prune.children[0].children[0].children) == 0
6363
assert len(tree_prune.children[0].children[1].children) == 2
6464

65+
@staticmethod
66+
def test_prune_tree_exact(tree_node):
67+
# Pruned tree is a/b/e/g
68+
tree_prune = prune_tree(tree_node, "a/b/e/g", exact=True)
69+
70+
assert_tree_structure_basenode_root(tree_node)
71+
assert_tree_structure_basenode_root_attr(tree_node)
72+
assert len(list(tree_prune.children)) == 1
73+
assert len(tree_prune.children[0].children) == 1
74+
assert len(tree_prune.children[0].children[0].children) == 1
75+
76+
@staticmethod
77+
def test_prune_tree_list(tree_node):
78+
# Pruned tree is a/b/d, a/b/e/g, a/b/e/h
79+
tree_prune = prune_tree(tree_node, ["a/b/d", "a/b/e/g", "a/b/e/h"])
80+
81+
assert_tree_structure_basenode_root(tree_node)
82+
assert_tree_structure_basenode_root_attr(tree_node)
83+
assert len(list(tree_prune.children)) == 1
84+
assert len(tree_prune.children[0].children) == 2
85+
assert len(tree_prune.children[0].children[0].children) == 0
86+
assert len(tree_prune.children[0].children[1].children) == 2
87+
6588
@staticmethod
6689
def test_prune_tree_path_and_depth(tree_node):
6790
# Pruned tree is a/b/d, a/b/e (a/b/e/g, a/b/e/h pruned away)
@@ -76,6 +99,36 @@ def test_prune_tree_path_and_depth(tree_node):
7699
len(tree_prune.children[0].children[1].children) == 0
77100
), "Depth at 4 is not pruned away"
78101

102+
@staticmethod
103+
def test_prune_tree_path_and_depth_and_list(tree_node):
104+
# Pruned tree is a/b/d, a/b/e (a/b/e/g, a/b/e/h pruned away)
105+
tree_prune = prune_tree(tree_node, ["a/b/d", "a/b/e/g", "a/b/e/h"], max_depth=3)
106+
107+
assert_tree_structure_basenode_root(tree_node)
108+
assert_tree_structure_basenode_root_attr(tree_node)
109+
assert len(list(tree_prune.children)) == 1
110+
assert len(tree_prune.children[0].children) == 2
111+
assert len(tree_prune.children[0].children[0].children) == 0
112+
assert (
113+
len(tree_prune.children[0].children[1].children) == 0
114+
), "Depth at 4 is not pruned away"
115+
116+
@staticmethod
117+
def test_prune_tree_path_and_depth_and_list_and_exact(tree_node):
118+
# Pruned tree is a/b/d, a/b/e (a/b/e/g, a/b/e/h pruned away)
119+
tree_prune = prune_tree(
120+
tree_node, ["a/b/d", "a/b/e/g", "a/b/e/h"], exact=True, max_depth=3
121+
)
122+
123+
assert_tree_structure_basenode_root(tree_node)
124+
assert_tree_structure_basenode_root_attr(tree_node)
125+
assert len(list(tree_prune.children)) == 1
126+
assert len(tree_prune.children[0].children) == 2
127+
assert len(tree_prune.children[0].children[0].children) == 0
128+
assert (
129+
len(tree_prune.children[0].children[1].children) == 0
130+
), "Depth at 4 is not pruned away"
131+
79132
@staticmethod
80133
def test_prune_tree_only_depth(tree_node):
81134
# Pruned tree is a/b, a/c (a/b/e/g, a/b/e/h, a/c/f pruned away)

0 commit comments

Comments
 (0)